diff --git a/benchmarks/ops/bench_logical_reduce.py b/benchmarks/ops/bench_logical_reduce.py index fcd03cbfe..fe98f3129 100644 --- a/benchmarks/ops/bench_logical_reduce.py +++ b/benchmarks/ops/bench_logical_reduce.py @@ -32,7 +32,7 @@ def test_any_bench(shape: tuple, dtype: torch.dtype) -> None: test = AnyTest(shape, dtype) inputs = test.gen_inputs() - op = AnyFwdOp(dtype=dtype) + op = AnyFwdOp(dtype=dtype, dim=-1) bm = ManifestBenchmark(_ANY_OP, op, test) try: result = bm.profile(op, *inputs) @@ -59,7 +59,7 @@ def test_all_bench(shape: tuple, dtype: torch.dtype) -> None: test = AllTest(shape, dtype) inputs = test.gen_inputs() - op = AllFwdOp(dtype=dtype) + op = AllFwdOp(dtype=dtype, dim=-1) bm = ManifestBenchmark(_ALL_OP, op, test) try: result = bm.profile(op, *inputs) @@ -86,7 +86,7 @@ def test_count_nonzero_bench(shape: tuple, dtype: torch.dtype) -> None: test = CountNonzeroTest(shape, dtype) inputs = test.gen_inputs() - op = CountNonzeroFwdOp(dtype=dtype) + op = CountNonzeroFwdOp(dtype=dtype, dim=-1) bm = ManifestBenchmark(_COUNT_NONZERO_OP, op, test) try: result = bm.profile(op, *inputs) diff --git a/benchmarks/ops/bench_reduce.py b/benchmarks/ops/bench_reduce.py index 77ff3a075..ddfb9f70b 100644 --- a/benchmarks/ops/bench_reduce.py +++ b/benchmarks/ops/bench_reduce.py @@ -58,6 +58,7 @@ def test_sum_bench( test = SumTest(shape, dtype) inputs = test.gen_inputs() + op_params.setdefault("dim", -1) # baseline below reduces dim=-1 op = SumFwdOp(dtype=dtype, **op_params) bm = ManifestBenchmark(_SUM_OP, op, test) try: @@ -88,7 +89,7 @@ def test_mean_bench(shape: tuple, dtype: torch.dtype) -> None: test = MeanTest(shape, dtype) inputs = test.gen_inputs() - op = MeanFwdOp(dtype=dtype) + op = MeanFwdOp(dtype=dtype, dim=-1) bm = ManifestBenchmark(_MEAN_OP, op, test) try: result = bm.profile(op, *inputs) @@ -115,7 +116,7 @@ def test_amax_bench(shape: tuple, dtype: torch.dtype) -> None: test = AmaxTest(shape, dtype) inputs = test.gen_inputs() - op = AmaxFwdOp(dtype=dtype) + op = AmaxFwdOp(dtype=dtype, dim=-1) bm = ManifestBenchmark(_AMAX_OP, op, test) try: result = bm.profile(op, *inputs) @@ -142,7 +143,7 @@ def test_amin_bench(shape: tuple, dtype: torch.dtype) -> None: test = AminTest(shape, dtype) inputs = test.gen_inputs() - op = AminFwdOp(dtype=dtype) + op = AminFwdOp(dtype=dtype, dim=-1) bm = ManifestBenchmark(_AMIN_OP, op, test) try: result = bm.profile(op, *inputs) @@ -196,7 +197,7 @@ def test_std_bench(shape: tuple, dtype: torch.dtype) -> None: test = StdTest(shape, dtype) inputs = test.gen_inputs() - op = StdFwdOp(dtype=dtype, correction=1) + op = StdFwdOp(dtype=dtype, dim=-1, correction=1) bm = ManifestBenchmark(_STD_OP, op, test) try: result = bm.profile(op, *inputs) @@ -223,7 +224,7 @@ def test_var_bench(shape: tuple, dtype: torch.dtype) -> None: test = VarTest(shape, dtype) inputs = test.gen_inputs() - op = VarFwdOp(dtype=dtype, correction=1) + op = VarFwdOp(dtype=dtype, dim=-1, correction=1) bm = ManifestBenchmark(_VAR_OP, op, test) try: result = bm.profile(op, *inputs) @@ -250,7 +251,7 @@ def test_var_mean_bench(shape: tuple, dtype: torch.dtype) -> None: test = VarMeanTest(shape, dtype) inputs = test.gen_inputs() - op = VarMeanFwdOp(dtype=dtype, correction=1) + op = VarMeanFwdOp(dtype=dtype, dim=-1, correction=1) bm = ManifestBenchmark(_VAR_MEAN_OP, op, test) try: result = bm.profile(op, *inputs) diff --git a/tests/ops/test_logical_reduce.py b/tests/ops/test_logical_reduce.py index 0f789cf28..d88ce2dd1 100644 --- a/tests/ops/test_logical_reduce.py +++ b/tests/ops/test_logical_reduce.py @@ -225,7 +225,7 @@ def test_any_op(m: int, n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.any_op import AnyFwdOp test = LogicalReduceTest(m, n, dtype, "any") - op = AnyFwdOp(dtype=dtype) + op = AnyFwdOp(dtype=dtype, dim=-1) test.check(op, *test.gen_inputs(), compare=_exact_compare) @@ -235,7 +235,7 @@ def test_any_non_contiguous(m: int, n: int, dtype: torch.dtype) -> None: x_full = _make_noncontig_input(m, n, dtype) x = x_full[:, :n] - op = AnyFwdOp(dtype=dtype) + op = AnyFwdOp(dtype=dtype, dim=-1) ref = x.contiguous().bool().any(dim=-1) y = op(x) assert y.dtype == torch.bool @@ -247,7 +247,7 @@ def test_any_3d(batch: int, seq: int, hidden: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.any_op import AnyFwdOp x = torch.randn(batch, seq, hidden, dtype=dtype, device="cuda") - op = AnyFwdOp(dtype=dtype) + op = AnyFwdOp(dtype=dtype, dim=-1) ref = x.bool().any(dim=-1) y = op(x) assert y.dtype == torch.bool @@ -259,7 +259,7 @@ def test_any_4d(b0: int, b1: int, b2: int, n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.any_op import AnyFwdOp x = torch.randn(b0, b1, b2, n, dtype=dtype, device="cuda") - op = AnyFwdOp(dtype=dtype) + op = AnyFwdOp(dtype=dtype, dim=-1) ref = x.bool().any(dim=-1) y = op(x) assert y.dtype == torch.bool @@ -271,7 +271,7 @@ def test_any_1d(n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.any_op import AnyFwdOp x = _make_1d_input(n, dtype) - op = AnyFwdOp(dtype=dtype) + op = AnyFwdOp(dtype=dtype, dim=-1) ref = x.bool().any(dim=-1) y = op(x) assert y.dtype == torch.bool @@ -314,7 +314,7 @@ def test_all_op(m: int, n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.all_op import AllFwdOp test = LogicalReduceTest(m, n, dtype, "all") - op = AllFwdOp(dtype=dtype) + op = AllFwdOp(dtype=dtype, dim=-1) test.check(op, *test.gen_inputs(), compare=_exact_compare) @@ -324,7 +324,7 @@ def test_all_non_contiguous(m: int, n: int, dtype: torch.dtype) -> None: x_full = _make_noncontig_input(m, n, dtype) x = x_full[:, :n] - op = AllFwdOp(dtype=dtype) + op = AllFwdOp(dtype=dtype, dim=-1) ref = x.contiguous().bool().all(dim=-1) y = op(x) assert y.dtype == torch.bool @@ -336,7 +336,7 @@ def test_all_3d(batch: int, seq: int, hidden: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.all_op import AllFwdOp x = torch.randn(batch, seq, hidden, dtype=dtype, device="cuda") - op = AllFwdOp(dtype=dtype) + op = AllFwdOp(dtype=dtype, dim=-1) ref = x.bool().all(dim=-1) y = op(x) assert y.dtype == torch.bool @@ -348,7 +348,7 @@ def test_all_4d(b0: int, b1: int, b2: int, n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.all_op import AllFwdOp x = torch.randn(b0, b1, b2, n, dtype=dtype, device="cuda") - op = AllFwdOp(dtype=dtype) + op = AllFwdOp(dtype=dtype, dim=-1) ref = x.bool().all(dim=-1) y = op(x) assert y.dtype == torch.bool @@ -360,7 +360,7 @@ def test_all_1d(n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.all_op import AllFwdOp x = _make_1d_input(n, dtype) - op = AllFwdOp(dtype=dtype) + op = AllFwdOp(dtype=dtype, dim=-1) ref = x.bool().all(dim=-1) y = op(x) assert y.dtype == torch.bool @@ -403,7 +403,7 @@ def test_count_nonzero_op(m: int, n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.count_nonzero import CountNonzeroFwdOp test = LogicalReduceTest(m, n, dtype, "count_nonzero") - op = CountNonzeroFwdOp(dtype=dtype) + op = CountNonzeroFwdOp(dtype=dtype, dim=-1) test.check(op, *test.gen_inputs(), compare=_exact_compare_int64) @@ -413,7 +413,7 @@ def test_count_nonzero_non_contiguous(m: int, n: int, dtype: torch.dtype) -> Non x_full = _make_noncontig_input(m, n, dtype) x = x_full[:, :n] - op = CountNonzeroFwdOp(dtype=dtype) + op = CountNonzeroFwdOp(dtype=dtype, dim=-1) ref = torch.count_nonzero(x.contiguous(), dim=-1).to(torch.int64) y = op(x) assert y.dtype == torch.int64 @@ -425,7 +425,7 @@ def test_count_nonzero_3d(batch: int, seq: int, hidden: int, dtype: torch.dtype) from tileops.ops.reduction.count_nonzero import CountNonzeroFwdOp x = torch.randn(batch, seq, hidden, dtype=dtype, device="cuda") - op = CountNonzeroFwdOp(dtype=dtype) + op = CountNonzeroFwdOp(dtype=dtype, dim=-1) ref = torch.count_nonzero(x, dim=-1).to(torch.int64) y = op(x) assert y.dtype == torch.int64 @@ -437,7 +437,7 @@ def test_count_nonzero_4d(b0: int, b1: int, b2: int, n: int, dtype: torch.dtype) from tileops.ops.reduction.count_nonzero import CountNonzeroFwdOp x = torch.randn(b0, b1, b2, n, dtype=dtype, device="cuda") - op = CountNonzeroFwdOp(dtype=dtype) + op = CountNonzeroFwdOp(dtype=dtype, dim=-1) ref = torch.count_nonzero(x, dim=-1).to(torch.int64) y = op(x) assert y.dtype == torch.int64 @@ -449,7 +449,7 @@ def test_count_nonzero_1d(n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.count_nonzero import CountNonzeroFwdOp x = _make_1d_input(n, dtype) - op = CountNonzeroFwdOp(dtype=dtype) + op = CountNonzeroFwdOp(dtype=dtype, dim=-1) ref = torch.count_nonzero(x, dim=-1).to(torch.int64) y = op(x) assert y.dtype == torch.int64 @@ -508,7 +508,7 @@ def test_any_smoke_float16(m: int, n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.any_op import AnyFwdOp test = LogicalReduceTest(m, n, dtype, "any") - op = AnyFwdOp(dtype=dtype) + op = AnyFwdOp(dtype=dtype, dim=-1) test.check(op, *test.gen_inputs(), compare=_exact_compare) @@ -517,7 +517,7 @@ def test_any_smoke_bfloat16(m: int, n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.any_op import AnyFwdOp test = LogicalReduceTest(m, n, dtype, "any") - op = AnyFwdOp(dtype=dtype) + op = AnyFwdOp(dtype=dtype, dim=-1) test.check(op, *test.gen_inputs(), compare=_exact_compare) @@ -526,7 +526,7 @@ def test_any_smoke_int32(m: int, n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.any_op import AnyFwdOp test = LogicalReduceTest(m, n, dtype, "any") - op = AnyFwdOp(dtype=dtype) + op = AnyFwdOp(dtype=dtype, dim=-1) test.check(op, *test.gen_inputs(), compare=_exact_compare) @@ -535,7 +535,7 @@ def test_any_smoke_int64(m: int, n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.any_op import AnyFwdOp test = LogicalReduceTest(m, n, dtype, "any") - op = AnyFwdOp(dtype=dtype) + op = AnyFwdOp(dtype=dtype, dim=-1) test.check(op, *test.gen_inputs(), compare=_exact_compare) @@ -544,7 +544,7 @@ def test_any_smoke_bool(m: int, n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.any_op import AnyFwdOp test = LogicalReduceTest(m, n, dtype, "any") - op = AnyFwdOp(dtype=dtype) + op = AnyFwdOp(dtype=dtype, dim=-1) test.check(op, *test.gen_inputs(), compare=_exact_compare) @@ -553,7 +553,7 @@ def test_all_smoke_float16(m: int, n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.all_op import AllFwdOp test = LogicalReduceTest(m, n, dtype, "all") - op = AllFwdOp(dtype=dtype) + op = AllFwdOp(dtype=dtype, dim=-1) test.check(op, *test.gen_inputs(), compare=_exact_compare) @@ -562,7 +562,7 @@ def test_all_smoke_bfloat16(m: int, n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.all_op import AllFwdOp test = LogicalReduceTest(m, n, dtype, "all") - op = AllFwdOp(dtype=dtype) + op = AllFwdOp(dtype=dtype, dim=-1) test.check(op, *test.gen_inputs(), compare=_exact_compare) @@ -571,7 +571,7 @@ def test_all_smoke_int32(m: int, n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.all_op import AllFwdOp test = LogicalReduceTest(m, n, dtype, "all") - op = AllFwdOp(dtype=dtype) + op = AllFwdOp(dtype=dtype, dim=-1) test.check(op, *test.gen_inputs(), compare=_exact_compare) @@ -580,7 +580,7 @@ def test_all_smoke_int64(m: int, n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.all_op import AllFwdOp test = LogicalReduceTest(m, n, dtype, "all") - op = AllFwdOp(dtype=dtype) + op = AllFwdOp(dtype=dtype, dim=-1) test.check(op, *test.gen_inputs(), compare=_exact_compare) @@ -589,7 +589,7 @@ def test_all_smoke_bool(m: int, n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.all_op import AllFwdOp test = LogicalReduceTest(m, n, dtype, "all") - op = AllFwdOp(dtype=dtype) + op = AllFwdOp(dtype=dtype, dim=-1) test.check(op, *test.gen_inputs(), compare=_exact_compare) @@ -598,7 +598,7 @@ def test_count_nonzero_smoke_float16(m: int, n: int, dtype: torch.dtype) -> None from tileops.ops.reduction.count_nonzero import CountNonzeroFwdOp test = LogicalReduceTest(m, n, dtype, "count_nonzero") - op = CountNonzeroFwdOp(dtype=dtype) + op = CountNonzeroFwdOp(dtype=dtype, dim=-1) test.check(op, *test.gen_inputs(), compare=_exact_compare_int64) @@ -607,7 +607,7 @@ def test_count_nonzero_smoke_bfloat16(m: int, n: int, dtype: torch.dtype) -> Non from tileops.ops.reduction.count_nonzero import CountNonzeroFwdOp test = LogicalReduceTest(m, n, dtype, "count_nonzero") - op = CountNonzeroFwdOp(dtype=dtype) + op = CountNonzeroFwdOp(dtype=dtype, dim=-1) test.check(op, *test.gen_inputs(), compare=_exact_compare_int64) @@ -616,7 +616,7 @@ def test_count_nonzero_smoke_int32(m: int, n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.count_nonzero import CountNonzeroFwdOp test = LogicalReduceTest(m, n, dtype, "count_nonzero") - op = CountNonzeroFwdOp(dtype=dtype) + op = CountNonzeroFwdOp(dtype=dtype, dim=-1) test.check(op, *test.gen_inputs(), compare=_exact_compare_int64) @@ -625,7 +625,7 @@ def test_count_nonzero_smoke_int64(m: int, n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.count_nonzero import CountNonzeroFwdOp test = LogicalReduceTest(m, n, dtype, "count_nonzero") - op = CountNonzeroFwdOp(dtype=dtype) + op = CountNonzeroFwdOp(dtype=dtype, dim=-1) test.check(op, *test.gen_inputs(), compare=_exact_compare_int64) @@ -634,7 +634,7 @@ def test_count_nonzero_smoke_bool(m: int, n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.count_nonzero import CountNonzeroFwdOp test = LogicalReduceTest(m, n, dtype, "count_nonzero") - op = CountNonzeroFwdOp(dtype=dtype) + op = CountNonzeroFwdOp(dtype=dtype, dim=-1) test.check(op, *test.gen_inputs(), compare=_exact_compare_int64) @@ -661,6 +661,7 @@ def test_logical_reduce_long_sequence_tiled(op_kind: str, dtype: torch.dtype) -> test = LogicalReduceTest(3, 33024, dtype, op_kind) op = op_map[op_kind]( dtype=dtype, + dim=-1, kernel_map={"logical_reduce": _TailBlockLogicalReduceKernel}, ) compare = _exact_compare_int64 if op_kind == "count_nonzero" else _exact_compare diff --git a/tests/ops/test_reduce.py b/tests/ops/test_reduce.py index 9c77fb444..658fc6763 100644 --- a/tests/ops/test_reduce.py +++ b/tests/ops/test_reduce.py @@ -202,7 +202,7 @@ def test_sum_op(m: int, n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.reduce import SumFwdOp test = ReduceTest(m, n, dtype, "sum") - op = SumFwdOp(dtype=dtype) + op = SumFwdOp(dtype=dtype, dim=-1) test.check(op, *test.gen_inputs(), **_tol(dtype)) @@ -211,7 +211,7 @@ def test_sum_tiled(m: int, n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.reduce import SumFwdOp test = ReduceTest(m, n, dtype, "sum") - op = SumFwdOp(dtype=dtype) + op = SumFwdOp(dtype=dtype, dim=-1) test.check(op, *test.gen_inputs(), **_tol(dtype)) @@ -220,7 +220,7 @@ def test_prod_tiled(m: int, n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.reduce import ProdFwdOp test = ProdTest(m, n, dtype) - op = ProdFwdOp(dtype=dtype) + op = ProdFwdOp(dtype=dtype, dim=-1) tol = {"atol": 5e-2, "rtol": 5e-2} if dtype != torch.float32 else {"atol": 1e-3, "rtol": 1e-3} test.check(op, *test.gen_inputs(), **tol) @@ -230,7 +230,7 @@ def test_var_tiled(m: int, n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.reduce import VarFwdOp test = WelfordTest(m, n, dtype, "var", correction=1) - op = VarFwdOp(dtype=dtype) + op = VarFwdOp(dtype=dtype, dim=-1) test.check(op, *test.gen_inputs(), **_tol(dtype)) @@ -240,7 +240,7 @@ def test_sum_non_contiguous(m: int, n: int, dtype: torch.dtype) -> None: x_full = torch.randn(m, n * 2, dtype=dtype, device="cuda") x = x_full[:, :n] - op = SumFwdOp(dtype=dtype) + op = SumFwdOp(dtype=dtype, dim=-1) ref = x.contiguous().float().sum(dim=-1).to(dtype) y = op(x) tol = _tol(dtype) @@ -252,7 +252,7 @@ def test_sum_3d(batch: int, seq: int, hidden: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.reduce import SumFwdOp x = torch.randn(batch, seq, hidden, dtype=dtype, device="cuda") - op = SumFwdOp(dtype=dtype) + op = SumFwdOp(dtype=dtype, dim=-1) ref = x.float().sum(dim=-1).to(dtype) y = op(x) tol = _tol(dtype) @@ -264,7 +264,7 @@ def test_sum_4d(b0: int, b1: int, b2: int, n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.reduce import SumFwdOp x = torch.randn(b0, b1, b2, n, dtype=dtype, device="cuda") - op = SumFwdOp(dtype=dtype) + op = SumFwdOp(dtype=dtype, dim=-1) ref = x.float().sum(dim=-1).to(dtype) y = op(x) tol = _tol(dtype) @@ -281,7 +281,7 @@ def test_mean_op(m: int, n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.reduce import MeanFwdOp test = ReduceTest(m, n, dtype, "mean") - op = MeanFwdOp(dtype=dtype) + op = MeanFwdOp(dtype=dtype, dim=-1) test.check(op, *test.gen_inputs(), **_tol(dtype)) @@ -295,7 +295,7 @@ def test_amin_op(m: int, n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.reduce import AminFwdOp test = ReduceTest(m, n, dtype, "amin") - op = AminFwdOp(dtype=dtype) + op = AminFwdOp(dtype=dtype, dim=-1) test.check(op, *test.gen_inputs(), **_tol(dtype)) @@ -309,7 +309,7 @@ def test_amax_op(m: int, n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.reduce import AmaxFwdOp test = ReduceTest(m, n, dtype, "amax") - op = AmaxFwdOp(dtype=dtype) + op = AmaxFwdOp(dtype=dtype, dim=-1) test.check(op, *test.gen_inputs(), **_tol(dtype)) @@ -323,7 +323,7 @@ def test_prod_op(m: int, n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.reduce import ProdFwdOp test = ProdTest(m, n, dtype) - op = ProdFwdOp(dtype=dtype) + op = ProdFwdOp(dtype=dtype, dim=-1) # Prod is more numerically sensitive tol = {"atol": 5e-2, "rtol": 5e-2} if dtype != torch.float32 else {"atol": 1e-3, "rtol": 1e-3} test.check(op, *test.gen_inputs(), **tol) @@ -339,7 +339,7 @@ def test_std_op(m: int, n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.reduce import StdFwdOp test = WelfordTest(m, n, dtype, "std", correction=1) - op = StdFwdOp(dtype=dtype) + op = StdFwdOp(dtype=dtype, dim=-1) test.check(op, *test.gen_inputs(), **_tol(dtype)) @@ -348,7 +348,7 @@ def test_std_bessel(m: int, n: int, dtype: torch.dtype, correction: int) -> None from tileops.ops.reduction.reduce import StdFwdOp test = WelfordTest(m, n, dtype, "std", correction=correction) - op = StdFwdOp(dtype=dtype, correction=correction) + op = StdFwdOp(dtype=dtype, correction=correction, dim=-1) test.check(op, *test.gen_inputs(), **_tol(dtype)) @@ -362,7 +362,7 @@ def test_var_op(m: int, n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.reduce import VarFwdOp test = WelfordTest(m, n, dtype, "var", correction=1) - op = VarFwdOp(dtype=dtype) + op = VarFwdOp(dtype=dtype, dim=-1) test.check(op, *test.gen_inputs(), **_tol(dtype)) @@ -371,7 +371,7 @@ def test_var_bessel(m: int, n: int, dtype: torch.dtype, correction: int) -> None from tileops.ops.reduction.reduce import VarFwdOp test = WelfordTest(m, n, dtype, "var", correction=correction) - op = VarFwdOp(dtype=dtype, correction=correction) + op = VarFwdOp(dtype=dtype, correction=correction, dim=-1) test.check(op, *test.gen_inputs(), **_tol(dtype)) @@ -385,7 +385,7 @@ def test_var_mean_op(m: int, n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.reduce import VarMeanFwdOp test = WelfordTest(m, n, dtype, "var_mean", correction=1) - op = VarMeanFwdOp(dtype=dtype) + op = VarMeanFwdOp(dtype=dtype, dim=-1) test.check(op, *test.gen_inputs(), **_tol(dtype)) @@ -394,7 +394,7 @@ def test_var_mean_bessel(m: int, n: int, dtype: torch.dtype, correction: int) -> from tileops.ops.reduction.reduce import VarMeanFwdOp test = WelfordTest(m, n, dtype, "var_mean", correction=correction) - op = VarMeanFwdOp(dtype=dtype, correction=correction) + op = VarMeanFwdOp(dtype=dtype, correction=correction, dim=-1) test.check(op, *test.gen_inputs(), **_tol(dtype)) @@ -408,7 +408,7 @@ def test_var_3d(batch: int, seq: int, hidden: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.reduce import VarFwdOp x = torch.randn(batch, seq, hidden, dtype=dtype, device="cuda") - op = VarFwdOp(dtype=dtype) + op = VarFwdOp(dtype=dtype, dim=-1) ref = x.float().var(dim=-1, correction=1).to(dtype) y = op(x) tol = _tol(dtype) @@ -420,7 +420,7 @@ def test_std_3d(batch: int, seq: int, hidden: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.reduce import StdFwdOp x = torch.randn(batch, seq, hidden, dtype=dtype, device="cuda") - op = StdFwdOp(dtype=dtype) + op = StdFwdOp(dtype=dtype, dim=-1) ref = x.float().std(dim=-1, correction=1).to(dtype) y = op(x) tol = _tol(dtype) @@ -437,7 +437,7 @@ def test_sum_1d(n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.reduce import SumFwdOp x = torch.randn(n, dtype=dtype, device="cuda") - op = SumFwdOp(dtype=dtype) + op = SumFwdOp(dtype=dtype, dim=-1) ref = x.float().sum(dim=-1).to(dtype) y = op(x) tol = _tol(dtype) @@ -451,7 +451,7 @@ def test_var_1d(n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.reduce import VarFwdOp x = torch.randn(n, dtype=dtype, device="cuda") - op = VarFwdOp(dtype=dtype) + op = VarFwdOp(dtype=dtype, dim=-1) ref = x.float().var(dim=-1, correction=1).to(dtype) y = op(x) tol = _tol(dtype) @@ -471,7 +471,7 @@ def test_var_non_contiguous(m: int, n: int, dtype: torch.dtype) -> None: x_full = torch.randn(m, n * 2, dtype=dtype, device="cuda") x = x_full[:, :n] - op = VarFwdOp(dtype=dtype) + op = VarFwdOp(dtype=dtype, dim=-1) ref = x.contiguous().float().var(dim=-1, correction=1).to(dtype) y = op(x) tol = _tol(dtype) @@ -484,7 +484,7 @@ def test_std_non_contiguous(m: int, n: int, dtype: torch.dtype) -> None: x_full = torch.randn(m, n * 2, dtype=dtype, device="cuda") x = x_full[:, :n] - op = StdFwdOp(dtype=dtype) + op = StdFwdOp(dtype=dtype, dim=-1) ref = x.contiguous().float().std(dim=-1, correction=1).to(dtype) y = op(x) tol = _tol(dtype) diff --git a/tests/ops/test_reduce_dim_none.py b/tests/ops/test_reduce_dim_none.py index 56d30d313..1284d8b48 100644 --- a/tests/ops/test_reduce_dim_none.py +++ b/tests/ops/test_reduce_dim_none.py @@ -151,24 +151,15 @@ def test_amin_dim_none( assert torch.allclose(y, ref, **tol), f"max err: {(y - ref).abs().max()}" -@DimNoneFixture -def test_prod_dim_none( - shape: tuple, keepdim: bool, dtype: torch.dtype, -) -> None: +@pytest.mark.smoke +def test_prod_dim_none_rejected() -> None: + """ProdFwdOp narrows ``dim`` to ``int`` per its manifest signature, so + ``dim=None`` (the full-reduction overload offered by the base) is + rejected at construction time.""" from tileops.ops.reduction.reduce import ProdFwdOp - # Use small uniform values to avoid overflow/underflow in prod - x = torch.rand(*shape, dtype=dtype, device="cuda") * 0.5 + 0.75 - op = ProdFwdOp(dtype=dtype, dim=None, keepdim=keepdim) - # PyTorch prod: reduce all dims manually (prod doesn't accept list[int]) - ref = x.float() - for d in sorted(_all_dims(shape), reverse=True): - ref = torch.prod(ref, dim=d, keepdim=keepdim) - ref = ref.to(dtype) - y = op(x) - tol = _tol(dtype) - assert y.shape == ref.shape, f"shape mismatch: {y.shape} vs {ref.shape}" - assert torch.allclose(y, ref, **tol), f"max err: {(y - ref).abs().max()}" + with pytest.raises(TypeError, match="ProdFwdOp.dim must be int"): + ProdFwdOp(dtype=torch.float16, dim=None) # --------------------------------------------------------------------------- diff --git a/tests/ops/test_reduce_multidim.py b/tests/ops/test_reduce_multidim.py index 77af4f8fa..dcd06c08d 100644 --- a/tests/ops/test_reduce_multidim.py +++ b/tests/ops/test_reduce_multidim.py @@ -112,23 +112,17 @@ def test_amax_multidim( assert torch.allclose(y, ref, **tol), f"max err: {(y - ref).abs().max()}" -@MultiDimFixture -def test_prod_multidim( - shape: tuple, dims: list, keepdim: bool, dtype: torch.dtype, -) -> None: +@pytest.mark.smoke +def test_prod_multidim_rejected() -> None: + """ProdFwdOp narrows ``dim`` to ``int`` per its manifest signature, so + the multi-dim (``list[int]`` / ``tuple[int, ...]``) overload is rejected + at construction time.""" from tileops.ops.reduction.reduce import ProdFwdOp - x = torch.randn(*shape, dtype=dtype, device="cuda") - op = ProdFwdOp(dtype=dtype, dim=dims, keepdim=keepdim) - # PyTorch doesn't support list[int] for prod, so iterate dims manually. - ref = x.float() - for d in sorted(dims, reverse=True): - ref = torch.prod(ref, dim=d, keepdim=keepdim) - ref = ref.to(dtype) - y = op(x) - tol = _tol(dtype) - assert y.shape == ref.shape, f"shape mismatch: {y.shape} vs {ref.shape}" - assert torch.allclose(y, ref, **tol), f"max err: {(y - ref).abs().max()}" + with pytest.raises(TypeError, match="ProdFwdOp.dim must be int"): + ProdFwdOp(dtype=torch.float16, dim=[0, 1]) + with pytest.raises(TypeError, match="ProdFwdOp.dim must be int"): + ProdFwdOp(dtype=torch.float16, dim=(0, 1)) @MultiDimFixture @@ -476,12 +470,13 @@ def test_var_mean_empty_dim_full_reduction() -> None: @pytest.mark.smoke def test_prod_empty_dim_rejects() -> None: + """ProdFwdOp narrows ``dim`` to ``int`` per its manifest signature, so + ``dim=[]`` is rejected by ``_validate_dim`` at construction (before + reaching the base class's ``empty_dim_policy`` branch).""" from tileops.ops.reduction.reduce import ProdFwdOp - x = torch.randn(2, 3, 4, dtype=torch.float16, device="cuda") - op = ProdFwdOp(dtype=torch.float16, dim=[], keepdim=False) - with pytest.raises(ValueError, match="dim=\\[\\] is not supported"): - op(x) + with pytest.raises(TypeError, match="ProdFwdOp.dim must be int"): + ProdFwdOp(dtype=torch.float16, dim=[], keepdim=False) @pytest.mark.smoke @@ -495,13 +490,17 @@ def test_logsumexp_empty_dim_rejects() -> None: @pytest.mark.smoke -def test_all_empty_dim_rejects() -> None: +def test_all_empty_dim_is_noop() -> None: + """AllFwdOp honors the spec's ``dim=[]`` no-op contract: output equals + ``x.bool()`` with the input shape.""" from tileops.ops.reduction.all_op import AllFwdOp x = (torch.randn(2, 3, 4, device="cuda") > 0).to(torch.float16) op = AllFwdOp(dtype=torch.float16, dim=[], keepdim=False) - with pytest.raises(ValueError, match="dim=\\[\\] is not supported"): - op(x) + y = op(x) + assert y.shape == x.shape + assert y.dtype == torch.bool + assert torch.equal(y, x.bool()) @pytest.mark.smoke diff --git a/tests/ops/test_reduction_defaults.py b/tests/ops/test_reduction_defaults.py new file mode 100644 index 000000000..da2d22d2f --- /dev/null +++ b/tests/ops/test_reduction_defaults.py @@ -0,0 +1,359 @@ +"""Regression tests for reduction-op constructor defaults and empty-dim semantics. + +Pins two manifest-conformance invariants for the reduction op family: + +1. For the ten ops whose manifest declares ``default: null`` on ``dim`` + (Sum/Mean/Amax/Amin/Var/Std/VarMean/All/Any/CountNonzero), constructing + the op with only ``dtype=`` performs a full reduction (output shape + equals ``torch.(x).shape``). ``ProdFwdOp`` keeps its documented + ``dim=-1`` default. + +2. ``AllFwdOp`` / ``AnyFwdOp`` honor the spec's ``dim=[]`` / ``dim=()`` + no-op contract: output shape equals the input shape, output dtype is + ``bool``, and values equal ``x.bool()``. +""" + +from __future__ import annotations + +import pytest +import torch + +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA required" +) + + +_FLOAT_SHAPE = (2, 4, 8) +_LOGICAL_SHAPE = (2, 4, 8) + + +def _make_float(shape: tuple, dtype: torch.dtype) -> torch.Tensor: + return torch.randn(*shape, dtype=dtype, device="cuda") + + +def _make_logical(shape: tuple, dtype: torch.dtype) -> torch.Tensor: + # values in {-1, 0, 1} so .bool() has both T and F. + return (torch.randint(-1, 2, shape, device="cuda")).to(dtype) + + +# --------------------------------------------------------------------------- +# AC-3: default dim=None for the ten ops -> full reduction on 3-D input +# --------------------------------------------------------------------------- + + +@pytest.mark.smoke +def test_sum_default_dim_full_reduction() -> None: + from tileops.ops.reduction.reduce import SumFwdOp + + x = _make_float(_FLOAT_SHAPE, torch.float16) + op = SumFwdOp(dtype=torch.float16) + y = op(x) + assert y.shape == torch.sum(x).shape + + +@pytest.mark.smoke +def test_mean_default_dim_full_reduction() -> None: + from tileops.ops.reduction.reduce import MeanFwdOp + + x = _make_float(_FLOAT_SHAPE, torch.float16) + op = MeanFwdOp(dtype=torch.float16) + y = op(x) + assert y.shape == torch.mean(x).shape + + +@pytest.mark.smoke +def test_amax_default_dim_full_reduction() -> None: + from tileops.ops.reduction.reduce import AmaxFwdOp + + x = _make_float(_FLOAT_SHAPE, torch.float16) + op = AmaxFwdOp(dtype=torch.float16) + y = op(x) + assert y.shape == torch.amax(x).shape + + +@pytest.mark.smoke +def test_amin_default_dim_full_reduction() -> None: + from tileops.ops.reduction.reduce import AminFwdOp + + x = _make_float(_FLOAT_SHAPE, torch.float16) + op = AminFwdOp(dtype=torch.float16) + y = op(x) + assert y.shape == torch.amin(x).shape + + +@pytest.mark.smoke +def test_var_default_dim_full_reduction() -> None: + from tileops.ops.reduction.reduce import VarFwdOp + + x = _make_float(_FLOAT_SHAPE, torch.float16) + op = VarFwdOp(dtype=torch.float16) + y = op(x) + assert y.shape == torch.var(x).shape + + +@pytest.mark.smoke +def test_std_default_dim_full_reduction() -> None: + from tileops.ops.reduction.reduce import StdFwdOp + + x = _make_float(_FLOAT_SHAPE, torch.float16) + op = StdFwdOp(dtype=torch.float16) + y = op(x) + assert y.shape == torch.std(x).shape + + +@pytest.mark.smoke +def test_var_mean_default_dim_full_reduction() -> None: + from tileops.ops.reduction.reduce import VarMeanFwdOp + + x = _make_float(_FLOAT_SHAPE, torch.float16) + op = VarMeanFwdOp(dtype=torch.float16) + var_out, mean_out = op(x) + ref_var, ref_mean = torch.var_mean(x) + assert var_out.shape == ref_var.shape + assert mean_out.shape == ref_mean.shape + + +@pytest.mark.smoke +def test_all_default_dim_full_reduction() -> None: + from tileops.ops.reduction.all_op import AllFwdOp + + x = _make_logical(_LOGICAL_SHAPE, torch.float16) + op = AllFwdOp(dtype=torch.float16) + y = op(x) + assert y.shape == torch.all(x.bool()).shape + assert y.dtype == torch.bool + + +@pytest.mark.smoke +def test_any_default_dim_full_reduction() -> None: + from tileops.ops.reduction.any_op import AnyFwdOp + + x = _make_logical(_LOGICAL_SHAPE, torch.float16) + op = AnyFwdOp(dtype=torch.float16) + y = op(x) + assert y.shape == torch.any(x.bool()).shape + assert y.dtype == torch.bool + + +@pytest.mark.smoke +def test_count_nonzero_default_dim_full_reduction() -> None: + from tileops.ops.reduction.count_nonzero import CountNonzeroFwdOp + + x = _make_logical(_LOGICAL_SHAPE, torch.float16) + op = CountNonzeroFwdOp(dtype=torch.float16) + y = op(x) + assert y.shape == torch.count_nonzero(x).shape + assert y.dtype == torch.int64 + + +# --------------------------------------------------------------------------- +# AC-4: ProdFwdOp keeps documented dim=-1 default +# --------------------------------------------------------------------------- + + +@pytest.mark.smoke +def test_prod_default_dim_last_axis() -> None: + from tileops.ops.reduction.reduce import ProdFwdOp + + # use a narrow value range so fp16 prod is numerically stable + x = torch.rand(*_FLOAT_SHAPE, dtype=torch.float16, device="cuda") * 0.01 + 0.99 + op = ProdFwdOp(dtype=torch.float16) + y = op(x) + assert y.shape == torch.prod(x, dim=-1).shape + + +# --------------------------------------------------------------------------- +# AC-5: AllFwdOp/AnyFwdOp dim=[] / dim=() noop contract +# --------------------------------------------------------------------------- + + +@pytest.mark.smoke +@pytest.mark.parametrize("empty_dim", [[], ()]) +def test_all_empty_dim_noop(empty_dim) -> None: + from tileops.ops.reduction.all_op import AllFwdOp + + x = _make_logical(_LOGICAL_SHAPE, torch.float16) + op = AllFwdOp(dtype=torch.float16, dim=empty_dim) + y = op(x) + assert y.shape == x.shape + assert y.dtype == torch.bool + assert torch.equal(y, x.bool()) + + +@pytest.mark.smoke +@pytest.mark.parametrize("empty_dim", [[], ()]) +def test_any_empty_dim_noop(empty_dim) -> None: + from tileops.ops.reduction.any_op import AnyFwdOp + + x = _make_logical(_LOGICAL_SHAPE, torch.float16) + op = AnyFwdOp(dtype=torch.float16, dim=empty_dim) + y = op(x) + assert y.shape == x.shape + assert y.dtype == torch.bool + assert torch.equal(y, x.bool()) + + +# --------------------------------------------------------------------------- +# AC-6: normalize_dim noop policy returns [] +# --------------------------------------------------------------------------- + + +@pytest.mark.smoke +def test_normalize_dim_noop_returns_empty() -> None: + from tileops.ops.reduction._multidim import normalize_dim + + assert normalize_dim([], ndim=3, empty_dim_policy="noop") == [] + assert normalize_dim((), ndim=3, empty_dim_policy="noop") == [] + + +@pytest.mark.smoke +def test_normalize_dim_reject_raises_on_empty() -> None: + from tileops.ops.reduction._multidim import normalize_dim + + with pytest.raises(ValueError): + normalize_dim([], ndim=3, empty_dim_policy="reject") + + +@pytest.mark.smoke +def test_normalize_dim_full_returns_all() -> None: + from tileops.ops.reduction._multidim import normalize_dim + + assert normalize_dim([], ndim=3, empty_dim_policy="full") == [0, 1, 2] + + +@pytest.mark.smoke +def test_empty_dim_policy_class_attrs() -> None: + """AC-6: per-op empty_dim_policy bindings.""" + from tileops.ops.reduction.all_op import AllFwdOp + from tileops.ops.reduction.any_op import AnyFwdOp + from tileops.ops.reduction.count_nonzero import CountNonzeroFwdOp + from tileops.ops.reduction.reduce import ( + AmaxFwdOp, + AminFwdOp, + MeanFwdOp, + ProdFwdOp, + StdFwdOp, + SumFwdOp, + VarFwdOp, + VarMeanFwdOp, + _ReduceOpBase, + ) + + assert _ReduceOpBase._empty_dim_policy == "reject" + assert AllFwdOp._empty_dim_policy == "noop" + assert AnyFwdOp._empty_dim_policy == "noop" + for cls in ( + SumFwdOp, MeanFwdOp, AmaxFwdOp, AminFwdOp, + StdFwdOp, VarFwdOp, VarMeanFwdOp, CountNonzeroFwdOp, + ): + assert cls._empty_dim_policy == "full", cls.__name__ + # ProdFwdOp inherits default (reject); empty dim is not in its contract + assert ProdFwdOp._empty_dim_policy == "reject" + + +# --------------------------------------------------------------------------- +# Empty-dim noop must NOT bypass input validation or roofline binding +# --------------------------------------------------------------------------- + + +@pytest.mark.smoke +def test_all_empty_dim_noop_rejects_cpu_tensor() -> None: + """dim=[] must still validate device; non-CUDA input must raise.""" + from tileops.ops.reduction.all_op import AllFwdOp + + x = (torch.randint(-1, 2, _LOGICAL_SHAPE)).to(torch.float16) # cpu + op = AllFwdOp(dtype=torch.float16, dim=[]) + with pytest.raises(ValueError, match="CUDA tensor"): + op(x) + + +@pytest.mark.smoke +def test_any_empty_dim_noop_rejects_cpu_tensor() -> None: + from tileops.ops.reduction.any_op import AnyFwdOp + + x = (torch.randint(-1, 2, _LOGICAL_SHAPE)).to(torch.float16) # cpu + op = AnyFwdOp(dtype=torch.float16, dim=[]) + with pytest.raises(ValueError, match="CUDA tensor"): + op(x) + + +@pytest.mark.smoke +def test_all_empty_dim_noop_rejects_wrong_dtype() -> None: + """dim=[] must still validate dtype against the op's declared dtype.""" + from tileops.ops.reduction.all_op import AllFwdOp + + x = _make_logical(_LOGICAL_SHAPE, torch.float32) # cuda, fp32 + op = AllFwdOp(dtype=torch.float16, dim=[]) + with pytest.raises(ValueError, match="Expected x.dtype"): + op(x) + + +@pytest.mark.smoke +def test_any_empty_dim_noop_rejects_wrong_dtype() -> None: + from tileops.ops.reduction.any_op import AnyFwdOp + + x = _make_logical(_LOGICAL_SHAPE, torch.float32) + op = AnyFwdOp(dtype=torch.float16, dim=[]) + with pytest.raises(ValueError, match="Expected x.dtype"): + op(x) + + +@pytest.mark.smoke +def test_all_empty_dim_noop_binds_roofline() -> None: + """eval_roofline() must succeed after a dim=[] noop forward and + report non-zero data-movement (the noop still reads the input and + writes an equal-shape cast result).""" + from tileops.ops.reduction.all_op import AllFwdOp + + x = _make_logical(_LOGICAL_SHAPE, torch.float16) + op = AllFwdOp(dtype=torch.float16, dim=[]) + op(x) + flops, mem_bytes = op.eval_roofline() + numel = x.numel() + elem_bytes = x.element_size() + # Noop binds (M=numel, N=1); for the "all" op_kind this gives + # mem_bytes = numel * elem_bytes + numel (input read + bool write). + expected_lower = numel * elem_bytes + expected_upper = 2 * numel * elem_bytes + numel + assert mem_bytes >= expected_lower, ( + f"noop bandwidth {mem_bytes} under-counts input read " + f"({expected_lower} bytes)" + ) + assert mem_bytes <= expected_upper + # flops are degenerate (one op per element); contract is non-negative. + assert flops >= 0 + + +@pytest.mark.smoke +def test_any_empty_dim_noop_binds_roofline() -> None: + from tileops.ops.reduction.any_op import AnyFwdOp + + x = _make_logical(_LOGICAL_SHAPE, torch.float16) + op = AnyFwdOp(dtype=torch.float16, dim=[]) + op(x) + flops, mem_bytes = op.eval_roofline() + numel = x.numel() + elem_bytes = x.element_size() + expected_lower = numel * elem_bytes + expected_upper = 2 * numel * elem_bytes + numel + assert mem_bytes >= expected_lower + assert mem_bytes <= expected_upper + assert flops >= 0 + +@pytest.mark.smoke +def test_validate_dim_rejects_bool_scalar() -> None: + """`bool` subclasses `int`, but a boolean dim is never a valid axis; + `_validate_dim` must reject it explicitly.""" + from tileops.ops.reduction.reduce import SumFwdOp + + with pytest.raises(TypeError, match="dim must not be bool"): + SumFwdOp(dtype=torch.float16, dim=True) + + +@pytest.mark.smoke +def test_validate_dim_rejects_bool_in_list() -> None: + """Same guard applies element-wise to `list[int]` / `tuple[int, ...]`.""" + from tileops.ops.reduction.reduce import SumFwdOp + + with pytest.raises(TypeError, match="must be int .not bool"): + SumFwdOp(dtype=torch.float16, dim=[True, 0]) diff --git a/tests/ops/test_welford_non_aligned.py b/tests/ops/test_welford_non_aligned.py index 7a137bac4..2d962d0e2 100644 --- a/tests/ops/test_welford_non_aligned.py +++ b/tests/ops/test_welford_non_aligned.py @@ -201,7 +201,7 @@ def test_var_non_aligned(m: int, n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.reduce import VarFwdOp test = WelfordNonAlignedTest((m, n), dtype, "var", correction=1) - op = VarFwdOp(dtype=dtype) + op = VarFwdOp(dim=-1, dtype=dtype) test.check(op, *test.gen_inputs(), **_tol(dtype)) @@ -216,7 +216,7 @@ def test_std_non_aligned(m: int, n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.reduce import StdFwdOp test = WelfordNonAlignedTest((m, n), dtype, "std", correction=1) - op = StdFwdOp(dtype=dtype) + op = StdFwdOp(dim=-1, dtype=dtype) test.check(op, *test.gen_inputs(), **_tol(dtype)) @@ -231,7 +231,7 @@ def test_var_mean_non_aligned(m: int, n: int, dtype: torch.dtype) -> None: from tileops.ops.reduction.reduce import VarMeanFwdOp test = WelfordNonAlignedTest((m, n), dtype, "var_mean", correction=1) - op = VarMeanFwdOp(dtype=dtype) + op = VarMeanFwdOp(dim=-1, dtype=dtype) test.check(op, *test.gen_inputs(), **_tol(dtype)) @@ -246,7 +246,7 @@ def test_var_3d_non_aligned(batch: int, seq: int, hidden: int, dtype: torch.dtyp from tileops.ops.reduction.reduce import VarFwdOp x = torch.randn(batch, seq, hidden, dtype=dtype, device="cuda") - op = VarFwdOp(dtype=dtype) + op = VarFwdOp(dim=-1, dtype=dtype) ref = x.float().var(dim=-1, correction=1).to(dtype) y = op(x) tol = _tol(dtype) @@ -259,7 +259,7 @@ def test_std_3d_non_aligned(batch: int, seq: int, hidden: int, dtype: torch.dtyp from tileops.ops.reduction.reduce import StdFwdOp x = torch.randn(batch, seq, hidden, dtype=dtype, device="cuda") - op = StdFwdOp(dtype=dtype) + op = StdFwdOp(dim=-1, dtype=dtype) ref = x.float().std(dim=-1, correction=1).to(dtype) y = op(x) tol = _tol(dtype) @@ -272,7 +272,7 @@ def test_var_mean_3d_non_aligned(batch: int, seq: int, hidden: int, dtype: torch from tileops.ops.reduction.reduce import VarMeanFwdOp x = torch.randn(batch, seq, hidden, dtype=dtype, device="cuda") - op = VarMeanFwdOp(dtype=dtype, correction=1) + op = VarMeanFwdOp(dim=-1, dtype=dtype, correction=1) ref_var = x.float().var(dim=-1, correction=1).to(dtype) ref_mean = x.float().mean(dim=-1).to(dtype) var_out, mean_out = op(x) diff --git a/tileops/ops/reduction/_multidim.py b/tileops/ops/reduction/_multidim.py index d62645190..ba675c72d 100644 --- a/tileops/ops/reduction/_multidim.py +++ b/tileops/ops/reduction/_multidim.py @@ -21,7 +21,7 @@ "restore_multidim_shape", ] -EmptyDimPolicy = Literal["reject", "full"] +EmptyDimPolicy = Literal["reject", "full", "noop"] def normalize_dim( @@ -36,12 +36,16 @@ def normalize_dim( dim: Single int, list of ints, or ``None`` (reduce all dims). ndim: Number of dimensions in the input tensor. empty_dim_policy: ``"reject"`` (default) raises on ``dim=[] / ()``; - ``"full"`` returns ``list(range(ndim))``. Each op opts in - explicitly because shared callers have different empty-dim - contracts. + ``"full"`` returns ``list(range(ndim))``; ``"noop"`` returns + ``[]``, signaling the caller to short-circuit and return the + input unchanged (modulo manifest-declared output-dtype cast). + Each op opts in explicitly because shared callers have + different empty-dim contracts. Returns: - Sorted list of non-negative dim indices (ascending). + Sorted list of non-negative dim indices (ascending). An empty + list is returned only when ``empty_dim_policy="noop"`` and the + caller passed ``dim=[]`` / ``dim=()``. Raises: IndexError: If any dim is out of range. @@ -56,9 +60,15 @@ def normalize_dim( if len(dims) == 0: if empty_dim_policy == "full": return list(range(ndim)) + if empty_dim_policy == "noop": + # Caller MUST detect [] and short-circuit before entering kernel + # paths -- the kernel does not handle a zero-dim reduction. + return [] raise ValueError( "dim=[] is not supported by this op; pass " - "empty_dim_policy=\"full\" to opt in to full-reduction." + "empty_dim_policy=\"full\" to opt in to full-reduction " + "or empty_dim_policy=\"noop\" to opt in to the identity " + "(return-input) contract." ) normalized = [] diff --git a/tileops/ops/reduction/all_op.py b/tileops/ops/reduction/all_op.py index ac8f5ef5e..a67f65f57 100644 --- a/tileops/ops/reduction/all_op.py +++ b/tileops/ops/reduction/all_op.py @@ -25,6 +25,7 @@ to_logical_float32, ) +from ._multidim import EmptyDimPolicy from .reduce import _ReduceOpBase __all__ = ["AllFwdOp"] @@ -33,9 +34,9 @@ class AllFwdOp(_ReduceOpBase): """All reduction along ``dim``, returning bool. - Construction: ``AllFwdOp(dtype=..., dim=-1, keepdim=False)``. M and N are - derived from the input tensor at forward time, and kernels are cached - by ``(M, N)`` to avoid rebuilds. + Construction: ``AllFwdOp(dtype=..., dim=None, keepdim=False)``. M and N + are derived from the input tensor at forward time, and kernels are + cached by ``(M, N)`` to avoid rebuilds. Padded positions use 1 (True), which is neutral for AND/all. @@ -43,11 +44,15 @@ class AllFwdOp(_ReduceOpBase): types. Inputs with unsupported TileLang storage dtypes (bool, int32, int64, complex64, complex128) are pre-converted to float32 in forward(). + Empty-dim contract: ``dim=[]`` / ``dim=()`` is a no-op -- forward returns + ``x.bool()`` with the input shape, matching ``torch.all`` semantics. + Args: dtype: Input data type (float16, bfloat16, float32, int32, int64, bool, complex64, complex128). - dim: Reduction dimension (default -1). Accepts ``int`` or - ``list[int]`` for multi-dim reduction. + dim: Reduction dimension (default ``None``, i.e. full reduction). + Accepts ``int``, ``list[int]``, or ``tuple[int, ...]`` for + multi-dim reduction. keepdim: Whether to retain the reduced dimension as size 1. kernel_map: Optional custom kernel map. tune: Whether to autotune the kernel. @@ -57,16 +62,28 @@ class AllFwdOp(_ReduceOpBase): _kernel_key = "logical_reduce" _kernel_cls = LogicalReduceKernel _kernel_handles_padding = True + _empty_dim_policy: EmptyDimPolicy = "noop" def __init__( self, *, dtype: torch.dtype, - dim: Union[int, List[int], None] = -1, + dim: Union[int, List[int], Tuple[int, ...], None] = None, keepdim: bool = False, kernel_map: Optional[Dict[str, Kernel]] = None, tune: bool = False, ): + """Construct AllFwdOp. + + Args: + dtype: Input data type. + dim: Reduction dimension (default ``None``, i.e. full reduction). + Accepts ``int``, ``list[int]``, ``tuple[int, ...]``, or + ``None``. + keepdim: Whether to retain reduced dims as size 1. + kernel_map: Optional override for kernel dispatch. + tune: Whether to autotune (default ``False``). + """ super().__init__( dtype=dtype, dim=dim, keepdim=keepdim, kernel_map=kernel_map, tune=tune, @@ -76,6 +93,10 @@ def _pad_value(self) -> float: """Pad with 1 (True), neutral for AND/all.""" return 1.0 + def _noop_output_dtype(self) -> torch.dtype: + """All returns bool per manifest contract.""" + return torch.bool + def _pre_kernel(self, x: torch.Tensor) -> Tuple[torch.Tensor, object]: """Convert unsupported storage dtypes to float32.""" if x.dtype in _UNSUPPORTED_STORAGE_DTYPES: diff --git a/tileops/ops/reduction/any_op.py b/tileops/ops/reduction/any_op.py index 5a5589765..aa8a73072 100644 --- a/tileops/ops/reduction/any_op.py +++ b/tileops/ops/reduction/any_op.py @@ -25,6 +25,7 @@ to_logical_float32, ) +from ._multidim import EmptyDimPolicy from .reduce import _ReduceOpBase __all__ = ["AnyFwdOp"] @@ -33,9 +34,9 @@ class AnyFwdOp(_ReduceOpBase): """Any reduction along ``dim``, returning bool. - Construction: ``AnyFwdOp(dtype=..., dim=-1, keepdim=False)``. M and N are - derived from the input tensor at forward time, and kernels are cached - by ``(M, N)`` to avoid rebuilds. + Construction: ``AnyFwdOp(dtype=..., dim=None, keepdim=False)``. M and N + are derived from the input tensor at forward time, and kernels are + cached by ``(M, N)`` to avoid rebuilds. Padded positions use 0 (False), which is neutral for OR/any. @@ -43,11 +44,15 @@ class AnyFwdOp(_ReduceOpBase): types. Inputs with unsupported TileLang storage dtypes (bool, int32, int64, complex64, complex128) are pre-converted to float32 in forward(). + Empty-dim contract: ``dim=[]`` / ``dim=()`` is a no-op -- forward returns + ``x.bool()`` with the input shape, matching ``torch.any`` semantics. + Args: dtype: Input data type (float16, bfloat16, float32, int32, int64, bool, complex64, complex128). - dim: Reduction dimension (default -1). Accepts ``int`` or - ``list[int]`` for multi-dim reduction. + dim: Reduction dimension (default ``None``, i.e. full reduction). + Accepts ``int``, ``list[int]``, or ``tuple[int, ...]`` for + multi-dim reduction. keepdim: Whether to retain the reduced dimension as size 1. kernel_map: Optional custom kernel map. tune: Whether to autotune the kernel. @@ -57,16 +62,28 @@ class AnyFwdOp(_ReduceOpBase): _kernel_key = "logical_reduce" _kernel_cls = LogicalReduceKernel _kernel_handles_padding = True + _empty_dim_policy: EmptyDimPolicy = "noop" def __init__( self, *, dtype: torch.dtype, - dim: Union[int, List[int], None] = -1, + dim: Union[int, List[int], Tuple[int, ...], None] = None, keepdim: bool = False, kernel_map: Optional[Dict[str, Kernel]] = None, tune: bool = False, ): + """Construct AnyFwdOp. + + Args: + dtype: Input data type. + dim: Reduction dimension (default ``None``, i.e. full reduction). + Accepts ``int``, ``list[int]``, ``tuple[int, ...]``, or + ``None``. + keepdim: Whether to retain reduced dims as size 1. + kernel_map: Optional override for kernel dispatch. + tune: Whether to autotune (default ``False``). + """ super().__init__( dtype=dtype, dim=dim, keepdim=keepdim, kernel_map=kernel_map, tune=tune, @@ -76,6 +93,10 @@ def _pad_value(self) -> float: """Pad with 0 (False), neutral for OR/any.""" return 0.0 + def _noop_output_dtype(self) -> torch.dtype: + """Any returns bool per manifest contract.""" + return torch.bool + def _pre_kernel(self, x: torch.Tensor) -> Tuple[torch.Tensor, object]: """Convert unsupported storage dtypes to float32.""" if x.dtype in _UNSUPPORTED_STORAGE_DTYPES: diff --git a/tileops/ops/reduction/count_nonzero.py b/tileops/ops/reduction/count_nonzero.py index aecfc8566..c25aafd87 100644 --- a/tileops/ops/reduction/count_nonzero.py +++ b/tileops/ops/reduction/count_nonzero.py @@ -37,7 +37,7 @@ class CountNonzeroFwdOp(_ReduceOpBase): """Count nonzero reduction along ``dim``, returning int64. - Construction: ``CountNonzeroFwdOp(dtype=..., dim=-1)``. M and N are + Construction: ``CountNonzeroFwdOp(dtype=..., dim=None)``. M and N are derived from the input tensor at forward time, and kernels are cached by ``(M, N)`` to avoid rebuilds. @@ -53,8 +53,9 @@ class CountNonzeroFwdOp(_ReduceOpBase): Args: dtype: Input data type (float16, bfloat16, float32, int32, int64, bool, complex64, complex128). - dim: Reduction dimension (default -1). Accepts ``int`` or - ``list[int]`` for multi-dim reduction. + dim: Reduction dimension (default ``None``, i.e. full reduction). + Accepts ``int``, ``list[int]``, or ``tuple[int, ...]`` for + multi-dim reduction. kernel_map: Optional custom kernel map. tune: Whether to autotune the kernel. """ @@ -69,7 +70,7 @@ def __init__( self, *, dtype: torch.dtype, - dim: Union[int, List[int], None] = -1, + dim: Union[int, List[int], Tuple[int, ...], None] = None, kernel_map: Optional[Dict[str, Kernel]] = None, tune: bool = False, ): diff --git a/tileops/ops/reduction/reduce.py b/tileops/ops/reduction/reduce.py index 245293f79..4c093c3e0 100644 --- a/tileops/ops/reduction/reduce.py +++ b/tileops/ops/reduction/reduce.py @@ -1,7 +1,10 @@ """Reduce ops: SumFwdOp, MeanFwdOp, AminFwdOp, AmaxFwdOp, ProdFwdOp, StdFwdOp, VarFwdOp, VarMeanFwdOp. Each op reduces along the configured ``dim`` and supports arbitrary-rank input. -The ``dim`` parameter accepts ``int`` or ``list[int]`` for multi-dim reduction. +The ``dim`` parameter accepts ``int``, ``list[int]``, or ``tuple[int, ...]`` +for multi-dim reduction. Constructor ``dim`` defaults to ``None`` (full +reduction) for the ten ops whose manifest declares ``default: null``; +``ProdFwdOp`` preserves ``dim=-1``. The Op layer validates inputs, reshapes to 2D (M, N), and calls the kernel. For simple, Welford, logical reduce, and vector norm ops, alignment padding is handled inside the kernel via masked loads with identity-element fills, @@ -75,11 +78,22 @@ def __init__( self, *, dtype: torch.dtype, - dim: Union[int, List[int], None] = -1, + dim: Union[int, List[int], Tuple[int, ...], None] = None, keepdim: bool = False, kernel_map: Optional[Dict[str, Kernel]] = None, tune: bool = False, ): + """Construct a reduce op. + + Args: + dtype: Input data type. + dim: Reduction dimension (default ``None``, i.e. full reduction). + Accepts ``int``, ``list[int]``, ``tuple[int, ...]``, or + ``None``. + keepdim: Whether to retain reduced dims as size 1. + kernel_map: Optional override for kernel dispatch. + tune: Whether to autotune (default ``False``). + """ self.dtype = dtype self.dim = dim self.keepdim = keepdim @@ -99,18 +113,31 @@ def _validate_dim(self) -> None: Default: accept ``int``, ``list[int]``/``tuple[int]``, or ``None``. Subclasses that only support single-dim reduction (e.g. argreduce) should override to reject non-scalar values. + + ``bool`` values are rejected explicitly. Python's ``bool`` subclasses + ``int`` (so ``isinstance(True, int)`` is true), but a boolean dim has + no meaningful interpretation as a tensor axis and almost always + signals a caller bug. """ dim = self.dim + if isinstance(dim, bool): + raise TypeError( + f"dim must not be bool (subclasses int but is not a valid " + f"axis), got {dim!r}" + ) if dim is None or isinstance(dim, int): return if isinstance(dim, (list, tuple)): - if not all(isinstance(d, int) for d in dim): - raise TypeError( - f"All elements of dim must be int, got {dim}" - ) + for d in dim: + if isinstance(d, bool) or not isinstance(d, int): + raise TypeError( + f"All elements of dim must be int (not bool), " + f"got {dim!r}" + ) return raise TypeError( - f"dim must be int, list[int], or None, got {type(dim).__name__}" + f"dim must be int, list[int], tuple[int, ...], or None, " + f"got {type(dim).__name__}" ) @property @@ -164,12 +191,73 @@ def _post_kernel(self, y: torch.Tensor, context: object) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor: """Run the reduce op on *x* along the configured dim.""" + noop_out = self._maybe_noop(x) + if noop_out is not None: + return noop_out x, orig_shape, dim_info, kernel = self._prepare_input(x) x, ctx = self._pre_kernel(x) y = kernel(x) y = self._post_kernel(y, ctx) return self._reshape_output(y, orig_shape, dim_info) + # ------------------------------------------------------------------ + # Empty-dim no-op short-circuit + # ------------------------------------------------------------------ + + def _noop_output_dtype(self) -> Optional[torch.dtype]: + """Manifest-declared output dtype for the no-op short-circuit. + + Subclasses with a fixed output dtype (e.g. All/Any -> bool) MUST + override so the short-circuit honors the manifest contract. The + default ``None`` means "preserve input dtype". + """ + return None + + def _validate_input_tensor(self, x: torch.Tensor) -> None: + """Validate device, dtype, and rank of the forward input. + + Shared by ``_prepare_input`` and the ``dim=[]`` noop short-circuit + so both paths enforce the same forward contract. + """ + if not x.is_cuda: + raise ValueError("x must be a CUDA tensor") + if x.dtype != self.dtype: + raise ValueError(f"Expected x.dtype {self.dtype}, got {x.dtype}") + if x.ndim == 0: + raise ValueError("Input tensor must be at least 1D") + + def _maybe_noop(self, x: torch.Tensor) -> Optional[torch.Tensor]: + """Return *x* (cast to the manifest output dtype) when ``dim`` is + an empty list/tuple and the op's ``_empty_dim_policy`` is + ``"noop"``; return ``None`` otherwise so the caller proceeds with + the normal kernel path. + + Runs the same input validation as ``_prepare_input`` (CUDA / dtype + / ndim) and binds ``_last_roofline_mn`` before short-circuiting, so + the noop path still honors the public forward contract -- bad + inputs raise, and ``eval_roofline()`` works after a noop forward. + """ + if self._empty_dim_policy != "noop": + return None + if not isinstance(self.dim, (list, tuple)) or len(self.dim) != 0: + return None + self._validate_input_tensor(x) + # Bind roofline state. The noop performs no reduction but still + # reads every input element and writes an equal-shape result + # (cast to bool for All/Any, the only ops whose ``_empty_dim_policy`` + # is ``"noop"``; other reduce ops, including ``CountNonzero``, keep + # ``"full"`` and never enter this branch). Model this as a + # degenerate reduction over an axis of length 1: M = numel, N = 1. + # Under the existing per-op-kind + # formulas this yields mem_bytes proportional to numel * elem_bytes + # for the read plus the output term, instead of collapsing to + # zero, which would under-count the actual data-movement cost. + self._last_roofline_mn = (x.numel(), 1) + out_dtype = self._noop_output_dtype() + if out_dtype is None: + return x + return x.to(out_dtype) + def eval_roofline(self) -> tuple[int, int]: if self._last_roofline_mn is None: raise RuntimeError( @@ -249,12 +337,7 @@ def _prepare_input( via masked loads. Otherwise, host-side ``F.pad`` is applied for backward compatibility with kernels that expect ``(M, N_padded)``. """ - if not x.is_cuda: - raise ValueError("x must be a CUDA tensor") - if x.dtype != self.dtype: - raise ValueError(f"Expected x.dtype {self.dtype}, got {x.dtype}") - if x.ndim == 0: - raise ValueError("Input tensor must be at least 1D") + self._validate_input_tensor(x) orig_shape = x.shape @@ -338,17 +421,20 @@ def _reshape_output( class _SimpleReduceOp(_ReduceOpBase): """Base for single-output reduce ops (sum, mean, amin, amax, prod). - Construction: ``op(dtype=..., dim=-1, keepdim=False)``. M and N are - derived from the input tensor at forward time, and kernels are cached - by ``(M, N)`` to avoid rebuilds. + M and N are derived from the input tensor at forward time, and kernels + are cached by ``(M, N)`` to avoid rebuilds. Alignment padding is handled + inside the kernel via masked loads with identity-element fills, so no + host-side ``F.pad`` is needed. - Alignment padding is handled inside the kernel via masked loads with - identity-element fills, so no host-side ``F.pad`` is needed. + The ``dim`` default follows each op's manifest entry: ``sum``, ``mean``, + ``amin``, and ``amax`` default to ``None`` (full reduction); ``prod`` + overrides to ``dim=-1`` and restricts the type to ``int``. Args: dtype: Data type (float32, float16, or bfloat16). - dim: Reduction dimension (default -1). Accepts ``int`` or - ``list[int]`` for multi-dim reduction. + dim: Reduction dimension. Accepts ``int``, ``list[int]``, + ``tuple[int, ...]``, or ``None`` on the base class; subclasses + may narrow this (see ``ProdFwdOp``). keepdim: Whether to retain the reduced dimension as size 1. kernel_map: Optional override for kernel dispatch. tune: Whether to autotune (default False). @@ -386,10 +472,46 @@ class AmaxFwdOp(_SimpleReduceOp): class ProdFwdOp(_SimpleReduceOp): - """Product reduction along dim=-1.""" + """Product reduction. + + Unlike the other simple reduce ops, ``ProdFwdOp`` defaults to + ``dim=-1`` (manifest declares ``default: -1`` for ``prod``). + """ _op_kind = "prod" + def __init__( + self, + *, + dtype: torch.dtype, + dim: int = -1, + keepdim: bool = False, + kernel_map: Optional[Dict[str, Kernel]] = None, + tune: bool = False, + ): + """Construct ProdFwdOp. + + Args: + dtype: Input data type. + dim (int): reduction dimension (default ``-1``). + keepdim: Whether to retain reduced dims as size 1. + kernel_map: Optional override for kernel dispatch. + tune: Whether to autotune (default ``False``). + """ + super().__init__( + dtype=dtype, dim=dim, keepdim=keepdim, + kernel_map=kernel_map, tune=tune, + ) + + def _validate_dim(self) -> None: + # Manifest declares prod.signature.params.dim as int; reject the + # multi-dim and full-reduction overloads inherited from the base. + if not isinstance(self.dim, int) or isinstance(self.dim, bool): + raise TypeError( + f"ProdFwdOp.dim must be int, got " + f"{type(self.dim).__name__}" + ) + # --------------------------------------------------------------------------- # Welford-based ops (std, var, var_mean) @@ -399,7 +521,7 @@ class ProdFwdOp(_SimpleReduceOp): class _WelfordReduceOp(_ReduceOpBase): """Base for Welford-based reduce ops (std, var, var_mean). - Construction: ``op(dtype=..., dim=-1, correction=1, keepdim=False)``. + Construction: ``op(dtype=..., dim=None, correction=1, keepdim=False)``. M and N are derived from the input tensor at forward time, and kernels are cached by ``(M, N)`` to avoid rebuilds. @@ -408,8 +530,9 @@ class _WelfordReduceOp(_ReduceOpBase): Args: dtype: Data type (float32, float16, or bfloat16). - dim: Reduction dimension (default -1). Accepts ``int`` or - ``list[int]`` for multi-dim reduction. + dim: Reduction dimension (default ``None``, i.e. full reduction). + Accepts ``int``, ``list[int]``, or ``tuple[int, ...]`` for + multi-dim reduction. correction: Bessel's correction (default 1). keepdim: Whether to retain the reduced dimension as size 1. kernel_map: Optional override for kernel dispatch. @@ -423,12 +546,24 @@ def __init__( self, *, dtype: torch.dtype, - dim: Union[int, List[int], None] = -1, + dim: Union[int, List[int], Tuple[int, ...], None] = None, correction: int = 1, keepdim: bool = False, kernel_map: Optional[Dict[str, Kernel]] = None, tune: bool = False, ): + """Construct a Welford-based reduce op. + + Args: + dtype: Input data type. + dim: Reduction dimension (default ``None``, i.e. full reduction). + Accepts ``int``, ``list[int]``, ``tuple[int, ...]``, or + ``None``. + correction: Bessel's correction (default 1). + keepdim: Whether to retain reduced dims as size 1. + kernel_map: Optional override for kernel dispatch. + tune: Whether to autotune (default ``False``). + """ self.correction = correction super().__init__( dtype=dtype, dim=dim, keepdim=keepdim,