Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions benchmark/test_binary_pointwise_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,29 @@ def get_tflops(self, op, *args, **kwargs):
return torch.tensor(shape1).prod().item() + torch.tensor(shape2).prod().item()


class ScalarBinaryPointwiseBenchmark(Benchmark):
    """Benchmark harness for binary pointwise ops whose first operand is a
    Python scalar and whose second operand is a tensor (e.g. ``pow(0.001, x)``).
    """

    # Report element throughput ("tflops") in addition to the default metrics.
    DEFAULT_METRICS = DEFAULT_METRICS + ["tflops"]

    def set_more_shapes(self):
        exponents = range(0, 20, 4)
        two_dim = [(1024, 2**e) for e in exponents]
        three_dim = [(64, 64, 2**e) for e in exponents]
        return two_dim + three_dim

    def get_input_iter(self, cur_dtype) -> Generator:
        scalar_operand = 0.001  # fixed scalar first operand for every shape
        for shape in self.shapes:
            tensor_operand = generate_tensor_input(shape, cur_dtype, self.device)
            yield scalar_operand, tensor_operand

    def get_tflops(self, op, *args, **kwargs):
        # Only the tensor operand (args[1]) contributes elements; args[0] is
        # the scalar.
        tensor_dims = list(args[1].shape)
        return torch.tensor(tensor_dims).prod().item()


@pytest.mark.parametrize(
"op_name, torch_op, dtypes",
[
Expand Down Expand Up @@ -119,3 +142,24 @@ def test_general_inplace_binary_pointwise_perf(op_name, torch_op, dtypes):
op_name=op_name, torch_op=torch_op, dtypes=dtypes, is_inplace=True
)
bench.run()


@pytest.mark.parametrize(
    "op_name, torch_op, dtypes",
    [
        pytest.param(
            case_name,
            case_op,
            case_dtypes,
            # Attach a per-op marker (e.g. @pytest.mark.pow) so individual
            # ops can be selected with -m.
            marks=getattr(pytest.mark, case_name, None),
        )
        for case_name, case_op, case_dtypes in [
            ("pow", lambda base, exponent: torch.pow(base, exponent), FLOAT_DTYPES),
        ]
    ],
)
def test_scalar_binary_pointwise_perf(op_name, torch_op, dtypes):
    """Performance test for binary pointwise ops with a scalar first operand."""
    benchmark = ScalarBinaryPointwiseBenchmark(
        op_name=op_name, torch_op=torch_op, dtypes=dtypes
    )
    benchmark.run()
2 changes: 2 additions & 0 deletions benchmark/test_tensor_constructor_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ def test_tensor_constructor_benchmark(op_name, torch_op, input_fn):
tensor_constructor_inplace_operations = [
# tensor constructor with given value
("fill_", torch.fill_, fill_input_fn),
("fill_scalar_", torch.ops.aten.fill_.Scalar, fill_input_fn),
# ("fill_scalar_", flag_gems.ops.fill.fill_scalar_, fill_input_fn),
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the FlagGems benchmark for fill_scalar_ commented out?

("masked_fill_", lambda a, b, c: a.masked_fill_(b, c), masked_fill_input_fn),
]

Expand Down
8 changes: 8 additions & 0 deletions src/flag_gems/runtime/backend/_nvidia/hopper/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
import triton


def _triton_version_tuple(version):
    """Parse the leading numeric (major, minor) components of *version*.

    Tolerates suffixes such as "3.4.0+git..." by keeping only the leading
    digits of each dot-separated component; a non-numeric component maps to 0.
    """
    parts = []
    for component in version.split(".")[:2]:
        digits = ""
        for ch in component:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


# BUG FIX: the original compared version strings lexicographically
# (triton.__version__ >= "3.4"), which is wrong once a component reaches two
# digits — e.g. "3.10" < "3.4" as strings. Compare numeric tuples instead.
if _triton_version_tuple(triton.__version__) >= (3, 4):
    from .fill import (  # noqa: F401
        fill_scalar,
        fill_scalar_,
        fill_scalar_out,
        fill_tensor,
        fill_tensor_,
        fill_tensor_out,
    )
    from .mm import mm, mm_out  # noqa: F401

__all__ = ["*"]
121 changes: 121 additions & 0 deletions src/flag_gems/runtime/backend/_nvidia/hopper/ops/fill.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import logging

import torch
import triton
import triton.language as tl

from flag_gems.runtime import torch_device_fn

logger = logging.getLogger(__name__)


@triton.jit
def fill_scalar_kernel(
    out_ptr,  # *T: pointer to the output buffer
    value_scalar,  # scalar fill value, cast to out_ptr's element type
    n_elements,  # total number of elements to fill
    BLOCK_SIZE: tl.constexpr,
):
    """Fill ``n_elements`` elements of ``out_ptr`` with ``value_scalar``."""
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # FIX: derive the element dtype directly from the pointer type instead of
    # loading a "dummy" value. The original load read uninitialized memory
    # (the output comes from torch.empty_like) and cost one extra global
    # memory transaction per block purely to infer the dtype.
    fill_val = tl.full([BLOCK_SIZE], value_scalar, dtype=out_ptr.dtype.element_ty)
    tl.store(out_ptr + offsets, fill_val, mask=mask)


@triton.jit
def fill_tensor_kernel(
    out_ptr,
    value_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Fill ``n_elements`` elements of ``out_ptr`` with the value at ``value_ptr``."""
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    # Single scalar load of the 0-dim value tensor; broadcast on store.
    fill_value = tl.load(value_ptr)
    tl.store(out_ptr + idx, fill_value, mask=in_bounds)


def fill_scalar(input, value):
    """Return a new tensor with ``input``'s shape/dtype filled with scalar ``value``."""
    logger.debug("GEMS FILL_SCALAR HOPPER")
    result = torch.empty_like(input)
    numel = result.numel()
    block_size = 1024
    launch_grid = (triton.cdiv(numel, block_size),)
    with torch_device_fn.device(input.device):
        fill_scalar_kernel[launch_grid](result, value, numel, BLOCK_SIZE=block_size)
    return result


def fill_scalar_out(input, value, *, out=None):
    """Fill ``out`` with scalar ``value``; allocates via ``fill_scalar`` when ``out`` is None."""
    logger.debug("GEMS FILL_SCALAR_OUT HOPPER")
    if out is None:
        return fill_scalar(input, value)
    numel = out.numel()
    block_size = 1024
    launch_grid = (triton.cdiv(numel, block_size),)
    with torch_device_fn.device(input.device):
        fill_scalar_kernel[launch_grid](out, value, numel, BLOCK_SIZE=block_size)
    return out


def fill_tensor(input, value):
    """Return a new tensor shaped like ``input`` filled with the 0-dim tensor ``value``."""
    # A host-resident value tensor cannot be read by the CUDA kernel;
    # fall back to the scalar path via .item().
    if not value.is_cuda:
        return fill_scalar(input, value.item())
    logger.debug("GEMS FILL_TENSOR HOPPER")
    if value.ndim != 0:
        raise RuntimeError(
            f"fill only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    result = torch.empty_like(input)
    numel = result.numel()
    block_size = 1024
    launch_grid = (triton.cdiv(numel, block_size),)
    with torch_device_fn.device(input.device):
        fill_tensor_kernel[launch_grid](result, value, numel, BLOCK_SIZE=block_size)
    return result


def fill_tensor_out(input, value, *, out=None):
    """Fill ``out`` with the 0-dim tensor ``value``; allocates when ``out`` is None."""
    logger.debug("GEMS FILL_TENSOR_OUT HOPPER")
    if out is None:
        return fill_tensor(input, value)
    # Host-resident value tensors take the scalar path via .item().
    if not value.is_cuda:
        return fill_scalar_out(input, value.item(), out=out)
    if value.ndim != 0:
        raise RuntimeError(
            f"fill only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    numel = out.numel()
    block_size = 1024
    launch_grid = (triton.cdiv(numel, block_size),)
    with torch_device_fn.device(input.device):
        fill_tensor_kernel[launch_grid](out, value, numel, BLOCK_SIZE=block_size)
    return out


def fill_tensor_(self, value):
    """In-place fill of ``self`` with the 0-dim tensor ``value``; returns ``self``."""
    # Host-resident value tensors take the scalar path via .item().
    if not value.is_cuda:
        return fill_scalar_(self, value.item())
    logger.debug("GEMS FILL_TENSOR_ HOPPER")
    if value.ndim != 0:
        raise RuntimeError(
            f"fill only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    numel = self.numel()
    block_size = 1024
    launch_grid = (triton.cdiv(numel, block_size),)
    with torch_device_fn.device(self.device):
        fill_tensor_kernel[launch_grid](self, value, numel, BLOCK_SIZE=block_size)
    return self


def fill_scalar_(self, value):
    """In-place fill of ``self`` with scalar ``value``; returns ``self``."""
    logger.debug("GEMS FILL_SCALAR_ HOPPER")
    numel = self.numel()
    block_size = 1024
    launch_grid = (triton.cdiv(numel, block_size),)
    with torch_device_fn.device(self.device):
        fill_scalar_kernel[launch_grid](self, value, numel, BLOCK_SIZE=block_size)
    return self
Loading