diff --git a/src/flag_gems/runtime/backend/_nvidia/hopper/ops/__init__.py b/src/flag_gems/runtime/backend/_nvidia/hopper/ops/__init__.py index 6cbe8cf151..9fa55f8e29 100644 --- a/src/flag_gems/runtime/backend/_nvidia/hopper/ops/__init__.py +++ b/src/flag_gems/runtime/backend/_nvidia/hopper/ops/__init__.py @@ -2,5 +2,6 @@ if triton.__version__ >= "3.4": from .mm import mm, mm_out # noqa: F401 + from .sqrt import sqrt, sqrt_ # noqa: F401 __all__ = ["*"] diff --git a/src/flag_gems/runtime/backend/_nvidia/hopper/ops/sqrt.py b/src/flag_gems/runtime/backend/_nvidia/hopper/ops/sqrt.py new file mode 100644 index 0000000000..ec526e6b54 --- /dev/null +++ b/src/flag_gems/runtime/backend/_nvidia/hopper/ops/sqrt.py @@ -0,0 +1,52 @@ +import logging + +import torch +import triton +import triton.language as tl + +logger = logging.getLogger("flag_gems.runtime.backend._nvidia.hopper.ops.sqrt") + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 512}, num_stages=4, num_warps=1), + triton.Config({"BLOCK_SIZE": 1024}, num_stages=4, num_warps=1), + triton.Config({"BLOCK_SIZE": 2048}, num_stages=4, num_warps=1), + ], + key=["n_elements"], +) +@triton.jit +def sqrt_kernel( + input_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(input_ptr + offsets, mask=mask) + x_fp32 = x.to(tl.float32) + output = tl.sqrt(x_fp32) + output = output.to(output_ptr.dtype.element_ty) + tl.store(output_ptr + offsets, output, mask=mask) + + +def sqrt(A): + logger.debug("GEMS SQRT HOPPER") + output = torch.empty_like(A) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + sqrt_kernel[grid](A, output, n_elements) + return output + + +def sqrt_(A): + logger.debug("GEMS SQRT_ HOPPER") + output = torch.empty_like(A) + n_elements = A.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + sqrt_kernel[grid](A, output, n_elements) + A.copy_(output) + return A