diff --git a/benchmark/test_binary_pointwise_perf.py b/benchmark/test_binary_pointwise_perf.py index 632d3d795c..a295eeadd9 100644 --- a/benchmark/test_binary_pointwise_perf.py +++ b/benchmark/test_binary_pointwise_perf.py @@ -37,6 +37,29 @@ def get_tflops(self, op, *args, **kwargs): return torch.tensor(shape1).prod().item() + torch.tensor(shape2).prod().item() +class ScalarBinaryPointwiseBenchmark(Benchmark): + """ + Base class for benchmarking binary pointwise operations with scalar input. + """ + + DEFAULT_METRICS = DEFAULT_METRICS[:] + ["tflops"] + + def set_more_shapes(self): + special_shapes_2d = [(1024, 2**i) for i in range(0, 20, 4)] + shapes_3d = [(64, 64, 2**i) for i in range(0, 20, 4)] + return special_shapes_2d + shapes_3d + + def get_input_iter(self, cur_dtype) -> Generator: + for shape in self.shapes: + inp1 = 0.001 # Scalar input + inp2 = generate_tensor_input(shape, cur_dtype, self.device) + yield inp1, inp2 + + def get_tflops(self, op, *args, **kwargs): + shape = list(args[1].shape) # Second argument is the tensor + return torch.tensor(shape).prod().item() + + @pytest.mark.parametrize( "op_name, torch_op, dtypes", [ @@ -119,3 +142,24 @@ def test_general_inplace_binary_pointwise_perf(op_name, torch_op, dtypes): op_name=op_name, torch_op=torch_op, dtypes=dtypes, is_inplace=True ) bench.run() + + +@pytest.mark.parametrize( + "op_name, torch_op, dtypes", + [ + pytest.param( + name, + op, + dtype, + marks=getattr(pytest.mark, name, None), + ) + for name, op, dtype in [ + ("pow", lambda a, b: torch.pow(a, b), FLOAT_DTYPES), + ] + ], +) +def test_scalar_binary_pointwise_perf(op_name, torch_op, dtypes): + bench = ScalarBinaryPointwiseBenchmark( + op_name=op_name, torch_op=torch_op, dtypes=dtypes + ) + bench.run() diff --git a/benchmark/test_tensor_constructor_perf.py b/benchmark/test_tensor_constructor_perf.py index ba59e6b011..c10008257d 100644 --- a/benchmark/test_tensor_constructor_perf.py +++ b/benchmark/test_tensor_constructor_perf.py @@ -190,6 +190,8 
@@ def test_tensor_constructor_benchmark(op_name, torch_op, input_fn): tensor_constructor_inplace_operations = [ # tensor constructor with given value ("fill_", torch.fill_, fill_input_fn), + ("fill_scalar_", torch.ops.aten.fill_.Scalar, fill_input_fn), + # ("fill_scalar_", flag_gems.ops.fill.fill_scalar_, fill_input_fn), ("masked_fill_", lambda a, b, c: a.masked_fill_(b, c), masked_fill_input_fn), ] diff --git a/src/flag_gems/runtime/backend/_nvidia/hopper/ops/__init__.py b/src/flag_gems/runtime/backend/_nvidia/hopper/ops/__init__.py index 6cbe8cf151..351cb742b6 100644 --- a/src/flag_gems/runtime/backend/_nvidia/hopper/ops/__init__.py +++ b/src/flag_gems/runtime/backend/_nvidia/hopper/ops/__init__.py @@ -1,6 +1,14 @@ import triton if triton.__version__ >= "3.4": + from .fill import ( # noqa: F401 + fill_scalar, + fill_scalar_, + fill_scalar_out, + fill_tensor, + fill_tensor_, + fill_tensor_out, + ) from .mm import mm, mm_out # noqa: F401 __all__ = ["*"] diff --git a/src/flag_gems/runtime/backend/_nvidia/hopper/ops/fill.py b/src/flag_gems/runtime/backend/_nvidia/hopper/ops/fill.py new file mode 100644 index 0000000000..d6eb80915b --- /dev/null +++ b/src/flag_gems/runtime/backend/_nvidia/hopper/ops/fill.py @@ -0,0 +1,121 @@ +import logging + +import torch +import triton +import triton.language as tl + +from flag_gems.runtime import torch_device_fn + +logger = logging.getLogger(__name__) + + +@triton.jit +def fill_scalar_kernel( + out_ptr, + value_scalar, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + # Take the element dtype straight from the output pointer (no dummy load). + fill_dtype = out_ptr.dtype.element_ty + fill_val = tl.full([BLOCK_SIZE], value_scalar, dtype=fill_dtype) + tl.store(out_ptr + offsets, fill_val, mask=mask) + + +@triton.jit +def fill_tensor_kernel( + out_ptr, + value_ptr, + n_elements, + BLOCK_SIZE:
tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + val = tl.load(value_ptr) + tl.store(out_ptr + offsets, val, mask=mask) + + +def fill_scalar(input, value): + logger.debug("GEMS FILL_SCALAR HOPPER") + out = torch.empty_like(input) + n_elements = out.numel() + grid = (triton.cdiv(n_elements, 1024),) + with torch_device_fn.device(input.device): + fill_scalar_kernel[grid](out, value, n_elements, BLOCK_SIZE=1024) + return out + + +def fill_scalar_out(input, value, *, out=None): + logger.debug("GEMS FILL_SCALAR_OUT HOPPER") + if out is None: + return fill_scalar(input, value) + n_elements = out.numel() + grid = (triton.cdiv(n_elements, 1024),) + with torch_device_fn.device(input.device): + fill_scalar_kernel[grid](out, value, n_elements, BLOCK_SIZE=1024) + return out + + +def fill_tensor(input, value): + if value.ndim != 0: + raise RuntimeError( + f"fill only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions." + ) + if not value.is_cuda: + return fill_scalar(input, value.item()) + logger.debug("GEMS FILL_TENSOR HOPPER") + out = torch.empty_like(input) + n_elements = out.numel() + grid = (triton.cdiv(n_elements, 1024),) + with torch_device_fn.device(input.device): + fill_tensor_kernel[grid](out, value, n_elements, BLOCK_SIZE=1024) + return out + + +def fill_tensor_out(input, value, *, out=None): + logger.debug("GEMS FILL_TENSOR_OUT HOPPER") + if out is None: + return fill_tensor(input, value) + if value.ndim != 0: + raise RuntimeError( + f"fill only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions." 
+ ) + if not value.is_cuda: + return fill_scalar_out(input, value.item(), out=out) + n_elements = out.numel() + grid = (triton.cdiv(n_elements, 1024),) + with torch_device_fn.device(input.device): + fill_tensor_kernel[grid](out, value, n_elements, BLOCK_SIZE=1024) + return out + + +def fill_tensor_(self, value): + if value.ndim != 0: + raise RuntimeError( + f"fill only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions." + ) + if not value.is_cuda: + return fill_scalar_(self, value.item()) + logger.debug("GEMS FILL_TENSOR_ HOPPER") + n_elements = self.numel() + grid = (triton.cdiv(n_elements, 1024),) + with torch_device_fn.device(self.device): + fill_tensor_kernel[grid](self, value, n_elements, BLOCK_SIZE=1024) + return self + + +def fill_scalar_(self, value): + logger.debug("GEMS FILL_SCALAR_ HOPPER") + n_elements = self.numel() + grid = (triton.cdiv(n_elements, 1024),) + with torch_device_fn.device(self.device): + fill_scalar_kernel[grid](self, value, n_elements, BLOCK_SIZE=1024) + return self