19 changes: 19 additions & 0 deletions src/flag_gems/runtime/backend/_iluvatar/ops/__init__.py
@@ -1,6 +1,25 @@
from .add import add, add_
from .div import div_mode, div_mode_
from .exponential_ import exponential_
from .mul import mul
from .pow import pow_scalar, pow_scalar_
from .sub import sub, sub_
from .true_divide import true_divide, true_divide_, true_divide_out
from .view_as_complex import view_as_complex

__all__ = [
"add",
"add_",
"div_mode",
"div_mode_",
"exponential_",
"mul",
"pow_scalar",
"pow_scalar_",
"sub",
"sub_",
"true_divide",
"true_divide_out",
"true_divide_",
"view_as_complex",
]
57 changes: 57 additions & 0 deletions src/flag_gems/runtime/backend/_iluvatar/ops/add.py
@@ -0,0 +1,57 @@
import logging

import torch
import triton

from flag_gems.utils import pointwise_dynamic

logger = logging.getLogger(__name__)


@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def add_func(x, y, alpha):
    return x + y * alpha


@pointwise_dynamic(
    is_tensor=[True, False, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def add_func_tensor_scalar(x, y, alpha):
    return x + y * alpha


@pointwise_dynamic(
    is_tensor=[False, True, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def add_func_scalar_tensor(x, y, alpha):
    return x + y * alpha


def add(A, B, *, alpha=1):
    logger.debug("GEMS ILUVATAR ADD")
    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
        if B.device != A.device:
            B = B.to(A.device)
        return add_func(A, B, alpha)
    elif isinstance(A, torch.Tensor):
        return add_func_tensor_scalar(A, B, alpha)
    elif isinstance(B, torch.Tensor):
        return add_func_scalar_tensor(A, B, alpha)
    else:
        return torch.tensor(A + B * alpha)


def add_(A, B, *, alpha=1):
    logger.debug("GEMS ILUVATAR ADD_")
    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
        if B.device != A.device:
            B = B.to(A.device)
        return add_func(A, B, alpha, out0=A)
    elif isinstance(A, torch.Tensor):
        return add_func_tensor_scalar(A, B, alpha, out0=A)
    # Note: scalar_tensor case not supported for in-place
    else:
        raise ValueError("Unreachable.")
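The dispatch above mirrors `torch.add`'s `input + alpha * other` semantics across the tensor/tensor, tensor/scalar, and scalar/tensor overloads. A minimal usage sketch of the intended behaviour (assuming FlagGems is enabled so `aten::add` resolves to this backend; the `"cuda"` device string is only illustrative):

```python
import torch
import flag_gems

flag_gems.enable()  # assumed setup; routes supported aten ops to flag_gems

a = torch.randn(4, device="cuda")
b = torch.randn(4, device="cuda")

out = torch.add(a, b, alpha=3)    # tensor + tensor: a + 3 * b
out = torch.add(a, 1.5, alpha=2)  # tensor + scalar: a + 2 * 1.5
a.add_(b, alpha=0.5)              # in-place: writes a + 0.5 * b into a
```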
205 changes: 205 additions & 0 deletions src/flag_gems/runtime/backend/_iluvatar/ops/exponential_.py
@@ -0,0 +1,205 @@
import logging

import torch
import triton
import triton.language as tl

from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry, libtuner
from flag_gems.utils.random_utils import (
    philox_backend_seed_offset,
    uint_to_uniform_float,
)

logger = logging.getLogger(__name__)


@triton.jit
def safe_fast_log_f32(x):
    min_normal = (x * 0.0 + 1.17549435e-38).to(tl.float32)
    max_u = x * 0.0 + 0.99999994
    x = tl.minimum(tl.maximum(x, min_normal), max_u)
    bits = x.to(tl.int32, bitcast=True)
    exponent = (bits >> 23) - 127
    mantissa = (bits & 0x7FFFFF).to(tl.float32) * (1.0 / 8388608.0) + 1.0
    m1 = mantissa - 1.0
    return (
        m1 * (1.0 + m1 * (-0.5 + m1 * (0.3333333333 - m1 * 0.25)))
        + exponent.to(tl.float32) * 0.6931471805599453
    )


@triton.jit
def safe_fast_log_f64(x):
    min_normal = x * 0.0 + 2.2250738585072014e-308
    max_u = x * 0.0 + (1.0 - 2.220446049250313e-16)
    x = tl.minimum(tl.maximum(x, min_normal), max_u)
    bits = x.to(tl.int64, bitcast=True)
    exponent = (bits >> 52) - 1023
    mantissa = (bits & 0x000FFFFFFFFFFFFF).to(tl.float64) * (
        1.0 / 4503599627370496.0
    ) + 1.0
    m1 = mantissa - 1.0
    return (
        m1 * (1.0 + m1 * (-0.5 + m1 * (0.3333333333333333 - m1 * 0.25)))
        + exponent.to(tl.float64) * 0.6931471805599453
    )
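For reference, both fast-log helpers clamp their argument into the open interval (0, 1) and then split it into mantissa and exponent, writing the input as x = (1 + t) * 2^e with t in [0, 1). The returned value is a fourth-order truncation of the Mercator series for ln(1 + t):

$$\ln x \approx t - \frac{t^{2}}{2} + \frac{t^{3}}{3} - \frac{t^{4}}{4} + e\,\ln 2,$$

which is exactly the polynomial `m1 * (1.0 + m1 * (-0.5 + m1 * (1/3 - m1 * 0.25)))` plus `exponent * 0.6931471805599453` above. The clamp keeps the result finite and strictly negative, so the transform `-log(u) / lambd` below stays positive.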


@triton.jit
def paste_u64(hi: tl.uint32, lo: tl.uint32):
    return (hi.to(tl.uint64) << 32) | lo.to(tl.uint64)


@triton.jit
def transform_exponential_f32_precise(u, inv_lambd, eps_minus):
    log = tl.where(u >= 1.0 + eps_minus, eps_minus, tl.math.log(u))
    return -inv_lambd * log


@triton.jit
def transform_exponential_f32_fast(u, inv_lambd, eps_minus):
    log = tl.where(u >= 1.0 + eps_minus, eps_minus, safe_fast_log_f32(u))
    return -inv_lambd * log


# Iluvatar uses the precise version for numerical stability
transform_exponential_f32 = transform_exponential_f32_precise


@triton.jit
def transform_exponential_f64(u, inv_lambd, eps_minus):
    log = tl.where(u >= 1.0 + eps_minus, eps_minus, safe_fast_log_f64(u))
    return -inv_lambd * log


@libentry()
@libtuner(
    configs=[
        triton.Config({"BLOCK": 64}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK": 128}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK": 512}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK": 1024}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK": 2048}, num_warps=8, num_stages=2),
    ],
    key=["N"],
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel_f32(
    out_ptr, N, inv_lambd, eps_minus, philox_seed, philox_offset, BLOCK: tl.constexpr
):
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit philox offset into low/high 32-bit counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    pid = tl.program_id(0)
    i = pid * BLOCK + tl.arange(0, BLOCK)
    # Give every lane a distinct counter; each philox call yields four 32-bit words.
    c0 += i
    z = c0 * 0
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, z, z)

    y0 = transform_exponential_f32(uint_to_uniform_float(r0), inv_lambd, eps_minus)
    y1 = transform_exponential_f32(uint_to_uniform_float(r1), inv_lambd, eps_minus)
    y2 = transform_exponential_f32(uint_to_uniform_float(r2), inv_lambd, eps_minus)
    y3 = transform_exponential_f32(uint_to_uniform_float(r3), inv_lambd, eps_minus)

    # Each program writes 4 * BLOCK contiguous elements; the masks guard the tail.
    start = pid.to(tl.uint64) * BLOCK * 4
    off0 = start + tl.arange(0, BLOCK)
    off1 = off0 + BLOCK
    off2 = off1 + BLOCK
    off3 = off2 + BLOCK

    tl.store(out_ptr + off0, y0, mask=off0 < N)
    tl.store(out_ptr + off1, y1, mask=off1 < N)
    tl.store(out_ptr + off2, y2, mask=off2 < N)
    tl.store(out_ptr + off3, y3, mask=off3 < N)


@libentry()
@libtuner(
    configs=[
        triton.Config({"BLOCK": 64}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK": 128}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK": 512}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK": 1024}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK": 2048}, num_warps=8, num_stages=2),
    ],
    key=["N"],
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel_f64(
    out_ptr, N, inv_lambd, eps_minus, philox_seed, philox_offset, BLOCK: tl.constexpr
):
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    pid = tl.program_id(0)
    i = pid * BLOCK + tl.arange(0, BLOCK)
    c0 += i
    z = c0 * 0
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, z, z)

    # Stitch two 32-bit words into one 64-bit sample so doubles get full-width randomness.
    u0 = uint_to_uniform_float(paste_u64(r0, r2))
    u1 = uint_to_uniform_float(paste_u64(r1, r3))

    y0 = transform_exponential_f64(u0, inv_lambd, eps_minus)
    y1 = transform_exponential_f64(u1, inv_lambd, eps_minus)

    # Each program writes 2 * BLOCK contiguous elements; the masks guard the tail.
    start = pid.to(tl.uint64) * BLOCK * 2
    off0 = start + tl.arange(0, BLOCK)
    off1 = off0 + BLOCK

    tl.store(out_ptr + off0, y0, mask=off0 < N)
    tl.store(out_ptr + off1, y1, mask=off1 < N)


def exponential_(x, lambd: float = 1.0, *, generator=None):
    logger.debug("GEMS_ILUVATAR EXPONENTIAL_")

    dtype = x.dtype
    device = x.device
    # The kernels fill a flat contiguous buffer; non-contiguous inputs go through a temporary.
    inplace = x.is_contiguous()
    assert dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)

    N = x.numel()

    # Handle empty tensor
    if N == 0:
        return x

    inv_lambd = 1.0 / lambd
    eps_minus = -0.5 * torch.finfo(dtype).eps

    out = x if inplace else torch.empty_like(x)

    if dtype is torch.float64:
        # f64 consumes two 32-bit philox words per element, so unroll by 2.
        UNROLL = 2
        grid = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),)
        increment = triton.cdiv(N, UNROLL)
        philox_seed, philox_offset = philox_backend_seed_offset(
            increment, generator=generator
        )
        with torch_device_fn.device(device):
            fused_exponential_kernel_f64[grid](
                out, N, inv_lambd, eps_minus, philox_seed, philox_offset
            )
    else:
        # f16/bf16/f32 consume one 32-bit philox word per element; each call yields four.
        UNROLL = 4
        grid = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),)
        increment = triton.cdiv(N, UNROLL)
        philox_seed, philox_offset = philox_backend_seed_offset(
            increment, generator=generator
        )
        with torch_device_fn.device(device):
            fused_exponential_kernel_f32[grid](
                out, N, inv_lambd, eps_minus, philox_seed, philox_offset
            )

    if not inplace:
        x.copy_(out)
    return x
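Both launch paths follow the usual inverse-CDF recipe: if U is uniform on (0, 1], then -ln(U)/lambd is exponentially distributed with rate lambd. Each Philox call yields four 32-bit words per lane, so a program covers BLOCK * 4 elements in the f32 kernel and BLOCK * 2 in the f64 kernel, which is where UNROLL comes from. A small sketch of the launch arithmetic with hypothetical sizes:

```python
import triton

# Hypothetical numbers, f32 path only, to illustrate the grid/offset math above.
N = 10_000   # elements to fill
BLOCK = 512  # one of the autotuned tile sizes
UNROLL = 4   # four 32-bit words per philox call per lane

num_programs = triton.cdiv(N, BLOCK * UNROLL)  # ceil(10000 / 2048) = 5
increment = triton.cdiv(N, UNROLL)             # 2500, the philox offset advance

# Program 4 starts at offset 4 * 2048 = 8192 and would reach 10239,
# which is why every tl.store above is masked with `off < N`.
print(num_programs, increment)
```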
47 changes: 47 additions & 0 deletions src/flag_gems/runtime/backend/_iluvatar/ops/pow.py
@@ -0,0 +1,47 @@
import logging

import triton
import triton.language as tl

from flag_gems.utils import pointwise_dynamic, tl_extra_shim

_pow = tl_extra_shim.pow
logger = logging.getLogger(__name__)


@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "BOOL_TO_LONG")])
@triton.jit
def pow_func_scalar_tensor(x, exponent):
    return _pow(x.to(tl.float32), exponent.to(tl.float32))


def pow_scalar(A, exponent):
    """
    Computes base**exponent where the base is a scalar and the exponent is a tensor.

    Uses the standard FlagGems pointwise_dynamic wrapper for hardware compatibility.

    Args:
        A: Scalar base value
        exponent: Exponent tensor

    Returns:
        Output tensor with the same shape as exponent
    """
    logger.debug("GEMS_ILUVATAR POW_SCALAR")
    return pow_func_scalar_tensor(A, exponent)


def pow_scalar_(A, exponent):
    """
    In-place version of pow_scalar.

    Args:
        A: Scalar base value
        exponent: Exponent tensor (modified in-place)

    Returns:
        The modified exponent tensor
    """
    logger.debug("GEMS_ILUVATAR POW_SCALAR_")
    return pow_func_scalar_tensor(A, exponent, out0=exponent)
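`pow_scalar` covers the `Scalar ** Tensor` overload of `torch.pow`. A small usage sketch (again assuming FlagGems is enabled so the aten op resolves to this backend; the `"cuda"` device string is only illustrative):

```python
import torch
import flag_gems

flag_gems.enable()  # assumed setup

t = torch.tensor([0.0, 1.0, 2.0, 3.0], device="cuda")
y = torch.pow(2.0, t)  # scalar base, tensor exponent -> [1., 2., 4., 8.]
```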