21 changes: 21 additions & 0 deletions src/flag_gems/runtime/backend/_iluvatar/ops/__init__.py
@@ -1,6 +1,27 @@
from .add import add, add_
from .div import div_mode, div_mode_
from .exponential_ import exponential_
from .mul import mul
from .ones import ones
from .pow import pow_scalar, pow_scalar_
from .repeat import repeat
from .sub import sub, sub_
from .true_divide import true_divide, true_divide_, true_divide_out

__all__ = [
"add",
"add_",
"div_mode",
"div_mode_",
"exponential_",
"mul",
"ones",
"pow_scalar",
"pow_scalar_",
"repeat",
"sub",
"sub_",
"true_divide",
"true_divide_out",
"true_divide_",
]
57 changes: 57 additions & 0 deletions src/flag_gems/runtime/backend/_iluvatar/ops/add.py
@@ -0,0 +1,57 @@
import logging

import torch
import triton

from flag_gems.utils import pointwise_dynamic

logger = logging.getLogger(__name__)


@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def add_func(x, y, alpha):
return x + y * alpha


@pointwise_dynamic(
is_tensor=[True, False, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def add_func_tensor_scalar(x, y, alpha):
return x + y * alpha


@pointwise_dynamic(
is_tensor=[False, True, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def add_func_scalar_tensor(x, y, alpha):
return x + y * alpha


def add(A, B, *, alpha=1):
logger.debug("GEMS ILUVATAR ADD")
if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
if B.device != A.device:
B = B.to(A.device)
return add_func(A, B, alpha)
elif isinstance(A, torch.Tensor):
return add_func_tensor_scalar(A, B, alpha)
elif isinstance(B, torch.Tensor):
return add_func_scalar_tensor(A, B, alpha)
else:
return torch.tensor(A + B * alpha)


def add_(A, B, *, alpha=1):
logger.debug("GEMS ILUVATAR ADD_")
if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
if B.device != A.device:
B = B.to(A.device)
return add_func(A, B, alpha, out0=A)
elif isinstance(A, torch.Tensor):
return add_func_tensor_scalar(A, B, alpha, out0=A)
# Note: scalar_tensor case not supported for in-place
else:
raise ValueError("Unreachable.")
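
A minimal usage sketch for the ops above (not part of the diff). It calls the backend functions directly rather than through the torch dispatcher, and it assumes the Iluvatar device is addressed as "cuda" in this build; the semantics mirror torch.add(A, B, alpha=alpha).

import torch

from flag_gems.runtime.backend._iluvatar.ops import add, add_

A = torch.randn(1024, device="cuda")  # assumption: Iluvatar device exposed as "cuda"
B = torch.randn(1024, device="cuda")

out = add(A, B, alpha=2.0)            # tensor + tensor * alpha
torch.testing.assert_close(out, A + 2.0 * B)

add_(A, 3.0, alpha=0.5)               # in-place tensor/scalar path: A += 3.0 * 0.5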
205 changes: 205 additions & 0 deletions src/flag_gems/runtime/backend/_iluvatar/ops/exponential_.py
@@ -0,0 +1,205 @@
import logging

import torch
import triton
import triton.language as tl

from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry, libtuner
from flag_gems.utils.random_utils import (
philox_backend_seed_offset,
uint_to_uniform_float,
)

logger = logging.getLogger(__name__)


@triton.jit
def safe_fast_log_f32(x):
min_normal = (x * 0.0 + 1.17549435e-38).to(tl.float32)
max_u = x * 0.0 + 0.99999994
x = tl.minimum(tl.maximum(x, min_normal), max_u)
bits = x.to(tl.int32, bitcast=True)
exponent = (bits >> 23) - 127
mantissa = (bits & 0x7FFFFF).to(tl.float32) * (1.0 / 8388608.0) + 1.0
m1 = mantissa - 1.0
return (
m1 * (1.0 + m1 * (-0.5 + m1 * (0.3333333333 - m1 * 0.25)))
+ exponent.to(tl.float32) * 0.6931471805599453
)


@triton.jit
def safe_fast_log_f64(x):
min_normal = x * 0.0 + 2.2250738585072014e-308
max_u = x * 0.0 + (1.0 - 2.220446049250313e-16)
x = tl.minimum(tl.maximum(x, min_normal), max_u)
bits = x.to(tl.int64, bitcast=True)
exponent = (bits >> 52) - 1023
mantissa = (bits & 0x000FFFFFFFFFFFFF).to(tl.float64) * (
1.0 / 4503599627370496.0
) + 1.0
m1 = mantissa - 1.0
return (
m1 * (1.0 + m1 * (-0.5 + m1 * (0.3333333333333333 - m1 * 0.25)))
+ exponent.to(tl.float64) * 0.6931471805599453
)


@triton.jit
def paste_u64(hi: tl.uint32, lo: tl.uint32):
return (hi.to(tl.uint64) << 32) | lo.to(tl.uint64)


@triton.jit
def transform_exponential_f32_precise(u, inv_lambd, eps_minus):
log = tl.where(u >= 1.0 + eps_minus, eps_minus, tl.math.log(u))
return -inv_lambd * log


@triton.jit
def transform_exponential_f32_fast(u, inv_lambd, eps_minus):
log = tl.where(u >= 1.0 + eps_minus, eps_minus, safe_fast_log_f32(u))
return -inv_lambd * log


# Iluvatar uses the precise version for numerical stability
transform_exponential_f32 = transform_exponential_f32_precise


@triton.jit
def transform_exponential_f64(u, inv_lambd, eps_minus):
log = tl.where(u >= 1.0 + eps_minus, eps_minus, safe_fast_log_f64(u))
return -inv_lambd * log


@libentry()
@libtuner(
configs=[
triton.Config({"BLOCK": 64}, num_warps=2, num_stages=2),
triton.Config({"BLOCK": 128}, num_warps=2, num_stages=2),
triton.Config({"BLOCK": 256}, num_warps=4, num_stages=2),
triton.Config({"BLOCK": 512}, num_warps=4, num_stages=3),
triton.Config({"BLOCK": 1024}, num_warps=4, num_stages=2),
triton.Config({"BLOCK": 2048}, num_warps=8, num_stages=2),
],
key=["N"],
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel_f32(
out_ptr, N, inv_lambd, eps_minus, philox_seed, philox_offset, BLOCK: tl.constexpr
):
philox_seed = philox_seed.to(tl.int64)
philox_offset = philox_offset.to(tl.int64)
c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

pid = tl.program_id(0)
i = pid * BLOCK + tl.arange(0, BLOCK)
c0 += i
z = c0 * 0
r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, z, z)

y0 = transform_exponential_f32(uint_to_uniform_float(r0), inv_lambd, eps_minus)
y1 = transform_exponential_f32(uint_to_uniform_float(r1), inv_lambd, eps_minus)
y2 = transform_exponential_f32(uint_to_uniform_float(r2), inv_lambd, eps_minus)
y3 = transform_exponential_f32(uint_to_uniform_float(r3), inv_lambd, eps_minus)

start = pid.to(tl.uint64) * BLOCK * 4
off0 = start + tl.arange(0, BLOCK)
off1 = off0 + BLOCK
off2 = off1 + BLOCK
off3 = off2 + BLOCK

tl.store(out_ptr + off0, y0, mask=off0 < N)
tl.store(out_ptr + off1, y1, mask=off1 < N)
tl.store(out_ptr + off2, y2, mask=off2 < N)
tl.store(out_ptr + off3, y3, mask=off3 < N)


@libentry()
@libtuner(
configs=[
triton.Config({"BLOCK": 64}, num_warps=2, num_stages=2),
triton.Config({"BLOCK": 128}, num_warps=2, num_stages=2),
triton.Config({"BLOCK": 256}, num_warps=4, num_stages=2),
triton.Config({"BLOCK": 512}, num_warps=4, num_stages=3),
triton.Config({"BLOCK": 1024}, num_warps=4, num_stages=2),
triton.Config({"BLOCK": 2048}, num_warps=8, num_stages=2),
],
key=["N"],
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel_f64(
out_ptr, N, inv_lambd, eps_minus, philox_seed, philox_offset, BLOCK: tl.constexpr
):
philox_seed = philox_seed.to(tl.int64)
philox_offset = philox_offset.to(tl.int64)
c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

pid = tl.program_id(0)
i = pid * BLOCK + tl.arange(0, BLOCK)
c0 += i
z = c0 * 0
r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, z, z)

u0 = uint_to_uniform_float(paste_u64(r0, r2))
u1 = uint_to_uniform_float(paste_u64(r1, r3))

y0 = transform_exponential_f64(u0, inv_lambd, eps_minus)
y1 = transform_exponential_f64(u1, inv_lambd, eps_minus)

start = pid.to(tl.uint64) * BLOCK * 2
off0 = start + tl.arange(0, BLOCK)
off1 = off0 + BLOCK

tl.store(out_ptr + off0, y0, mask=off0 < N)
tl.store(out_ptr + off1, y1, mask=off1 < N)


def exponential_(x, lambd: float = 1.0, *, generator=None):
logger.debug("GEMS_ILUVATAR EXPONENTIAL_")

dtype = x.dtype
device = x.device
inplace = x.is_contiguous()
assert dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)

N = x.numel()

# Handle empty tensor
if N == 0:
return x

inv_lambd = 1.0 / lambd
eps_minus = -0.5 * torch.finfo(dtype).eps

out = x if inplace else torch.empty_like(x)

if dtype is torch.float64:
UNROLL = 2
grid = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),)
increment = triton.cdiv(N, UNROLL)
philox_seed, philox_offset = philox_backend_seed_offset(
increment, generator=generator
)
with torch_device_fn.device(device):
fused_exponential_kernel_f64[grid](
out, N, inv_lambd, eps_minus, philox_seed, philox_offset
)
else:
UNROLL = 4
grid = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),)
increment = triton.cdiv(N, UNROLL)
philox_seed, philox_offset = philox_backend_seed_offset(
increment, generator=generator
)
with torch_device_fn.device(device):
fused_exponential_kernel_f32[grid](
out, N, inv_lambd, eps_minus, philox_seed, philox_offset
)

if not inplace:
x.copy_(out)
return x
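
The kernels above draw Philox uniforms and apply the inverse-CDF transform x = -log(u) / lambd; safe_fast_log_* clamps u away from zero so the log stays finite, and the eps_minus substitution keeps the sample strictly positive when u rounds to 1. A host-side sketch of the same transform for a quick sanity check (any uniform source works; this is not part of the diff):

import torch

lambd = 2.0
u = torch.rand(1_000_000, dtype=torch.float64).clamp_min(1e-12)  # uniform in (0, 1)
x = -torch.log(u) / lambd            # same per-element transform as the kernels

# Exponential(lambd) has mean 1/lambd and variance 1/lambd**2
print(x.mean().item())               # ~0.5
print(x.var().item())                # ~0.25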
47 changes: 47 additions & 0 deletions src/flag_gems/runtime/backend/_iluvatar/ops/pow.py
@@ -0,0 +1,47 @@
import logging

import triton
import triton.language as tl

from flag_gems.utils import pointwise_dynamic, tl_extra_shim

_pow = tl_extra_shim.pow
logger = logging.getLogger(__name__)


@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "BOOL_TO_LONG")])
@triton.jit
def pow_func_scalar_tensor(x, exponent):
return _pow(x.to(tl.float32), exponent.to(tl.float32))


def pow_scalar(A, exponent):
"""
Computes base^exponent where base is a scalar and exponent is a tensor.

Uses FlagGems standard pointwise_dynamic for hardware compatibility.

Args:
A: Scalar base value
exponent: Exponent tensor

Returns:
Output tensor with same shape as exponent
"""
logger.debug("GEMS_ILUVATAR POW_SCALAR")
return pow_func_scalar_tensor(A, exponent)


def pow_scalar_(A, exponent):
"""
In-place version of pow_scalar.

Args:
A: Scalar base value
exponent: Exponent tensor (modified in-place)

Returns:
The modified exponent tensor
"""
logger.debug("GEMS_ILUVATAR POW_SCALAR_")
return pow_func_scalar_tensor(A, exponent, out0=exponent)
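
A brief usage sketch for the scalar-base pow above (not part of the diff), again assuming the device is reachable as "cuda":

import torch

from flag_gems.runtime.backend._iluvatar.ops import pow_scalar, pow_scalar_

exponent = torch.randn(512, device="cuda")

out = pow_scalar(2.0, exponent)       # elementwise 2.0 ** exponent
# should closely match torch.pow(2.0, exponent) up to float32 rounding

pow_scalar_(2.0, exponent)            # writes the result back into `exponent`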