9 changes: 9 additions & 0 deletions src/flag_gems/runtime/backend/_iluvatar/ops/__init__.py
@@ -1,6 +1,15 @@
from .div import div_mode, div_mode_
from .exponential_ import exponential_
from .pow import pow_scalar, pow_scalar_
from .true_divide import true_divide, true_divide_, true_divide_out

__all__ = [
"div_mode",
"div_mode_",
"exponential_",
"pow_scalar",
"pow_scalar_",
"true_divide",
"true_divide_out",
"true_divide_",
]
205 changes: 205 additions & 0 deletions src/flag_gems/runtime/backend/_iluvatar/ops/exponential_.py
@@ -0,0 +1,205 @@
import logging

import torch
import triton
import triton.language as tl

from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry, libtuner
from flag_gems.utils.random_utils import (
philox_backend_seed_offset,
uint_to_uniform_float,
)

logger = logging.getLogger(__name__)


@triton.jit
def safe_fast_log_f32(x):
    # Approximate natural log via IEEE-754 bit manipulation: clamp x into
    # [min_normal, 1), split the bits into exponent and mantissa, approximate
    # log(mantissa) with a short polynomial in (mantissa - 1), and add
    # exponent * ln(2).
    min_normal = (x * 0.0 + 1.17549435e-38).to(tl.float32)
    max_u = x * 0.0 + 0.99999994
    x = tl.minimum(tl.maximum(x, min_normal), max_u)
bits = x.to(tl.int32, bitcast=True)
exponent = (bits >> 23) - 127
mantissa = (bits & 0x7FFFFF).to(tl.float32) * (1.0 / 8388608.0) + 1.0
m1 = mantissa - 1.0
return (
m1 * (1.0 + m1 * (-0.5 + m1 * (0.3333333333 - m1 * 0.25)))
+ exponent.to(tl.float32) * 0.6931471805599453
)


@triton.jit
def safe_fast_log_f64(x):
    # float64 analogue of safe_fast_log_f32: clamp to [min_normal, 1 - eps),
    # then split the IEEE-754 bits into an 11-bit exponent and 52-bit mantissa.
    min_normal = x * 0.0 + 2.2250738585072014e-308
    max_u = x * 0.0 + (1.0 - 2.220446049250313e-16)
    x = tl.minimum(tl.maximum(x, min_normal), max_u)
bits = x.to(tl.int64, bitcast=True)
exponent = (bits >> 52) - 1023
mantissa = (bits & 0x000FFFFFFFFFFFFF).to(tl.float64) * (
1.0 / 4503599627370496.0
) + 1.0
m1 = mantissa - 1.0
return (
m1 * (1.0 + m1 * (-0.5 + m1 * (0.3333333333333333 - m1 * 0.25)))
+ exponent.to(tl.float64) * 0.6931471805599453
)
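
The fast-log helpers split an IEEE-754 float into exponent and mantissa and approximate log(mantissa) with a four-term series in (mantissa - 1), adding exponent * ln(2). A minimal host-side NumPy sketch of the float32 variant, for comparison against np.log (illustrative only, not part of the kernel code):

import numpy as np

def fast_log_f32(x):
    # Host-side mirror of safe_fast_log_f32, for comparison against np.log.
    x = np.clip(np.asarray(x, dtype=np.float32), 1.17549435e-38, 0.99999994).astype(np.float32)
    bits = x.view(np.int32)
    exponent = (bits >> 23) - 127
    mantissa = (bits & 0x7FFFFF).astype(np.float32) / np.float32(8388608.0) + np.float32(1.0)
    m1 = mantissa - np.float32(1.0)
    poly = m1 * (1.0 + m1 * (-0.5 + m1 * (0.3333333333 - m1 * 0.25)))
    return poly + exponent.astype(np.float32) * np.float32(0.6931471805599453)

u = np.linspace(0.01, 0.99, 7, dtype=np.float32)
print(np.abs(fast_log_f32(u) - np.log(u)))  # truncated-series error, largest where the mantissa approaches 2
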


@triton.jit
def paste_u64(hi: tl.uint32, lo: tl.uint32):
    # Concatenate two 32-bit random words into one 64-bit word (hi in the upper bits).
    return (hi.to(tl.uint64) << 32) | lo.to(tl.uint64)


@triton.jit
def transform_exponential_f32_precise(u, inv_lambd, eps_minus):
    # Inverse-CDF transform y = -log(u) / lambd; uniforms at or above
    # 1 + eps_minus use eps_minus in place of log(u), so the sample is never exactly zero.
    log = tl.where(u >= 1.0 + eps_minus, eps_minus, tl.math.log(u))
    return -inv_lambd * log


@triton.jit
def transform_exponential_f32_fast(u, inv_lambd, eps_minus):
log = tl.where(u >= 1.0 + eps_minus, eps_minus, safe_fast_log_f32(u))
return -inv_lambd * log


# Iluvatar uses the precise (tl.math.log) version for numerical stability;
# the fast bit-manipulation variant above is left unused for float32.
transform_exponential_f32 = transform_exponential_f32_precise


@triton.jit
def transform_exponential_f64(u, inv_lambd, eps_minus):
log = tl.where(u >= 1.0 + eps_minus, eps_minus, safe_fast_log_f64(u))
return -inv_lambd * log


@libentry()
@libtuner(
configs=[
triton.Config({"BLOCK": 64}, num_warps=2, num_stages=2),
triton.Config({"BLOCK": 128}, num_warps=2, num_stages=2),
triton.Config({"BLOCK": 256}, num_warps=4, num_stages=2),
triton.Config({"BLOCK": 512}, num_warps=4, num_stages=3),
triton.Config({"BLOCK": 1024}, num_warps=4, num_stages=2),
triton.Config({"BLOCK": 2048}, num_warps=8, num_stages=2),
],
key=["N"],
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel_f32(
out_ptr, N, inv_lambd, eps_minus, philox_seed, philox_offset, BLOCK: tl.constexpr
):
    # Split the 64-bit Philox offset into two 32-bit counter words.
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    # Give every element its own counter, then draw four 32-bit words per element.
    pid = tl.program_id(0)
    i = pid * BLOCK + tl.arange(0, BLOCK)
    c0 += i
    z = c0 * 0
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, z, z)

y0 = transform_exponential_f32(uint_to_uniform_float(r0), inv_lambd, eps_minus)
y1 = transform_exponential_f32(uint_to_uniform_float(r1), inv_lambd, eps_minus)
y2 = transform_exponential_f32(uint_to_uniform_float(r2), inv_lambd, eps_minus)
y3 = transform_exponential_f32(uint_to_uniform_float(r3), inv_lambd, eps_minus)

start = pid.to(tl.uint64) * BLOCK * 4
off0 = start + tl.arange(0, BLOCK)
off1 = off0 + BLOCK
off2 = off1 + BLOCK
off3 = off2 + BLOCK

tl.store(out_ptr + off0, y0, mask=off0 < N)
tl.store(out_ptr + off1, y1, mask=off1 < N)
tl.store(out_ptr + off2, y2, mask=off2 < N)
tl.store(out_ptr + off3, y3, mask=off3 < N)


@libentry()
@libtuner(
configs=[
triton.Config({"BLOCK": 64}, num_warps=2, num_stages=2),
triton.Config({"BLOCK": 128}, num_warps=2, num_stages=2),
triton.Config({"BLOCK": 256}, num_warps=4, num_stages=2),
triton.Config({"BLOCK": 512}, num_warps=4, num_stages=3),
triton.Config({"BLOCK": 1024}, num_warps=4, num_stages=2),
triton.Config({"BLOCK": 2048}, num_warps=8, num_stages=2),
],
key=["N"],
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel_f64(
out_ptr, N, inv_lambd, eps_minus, philox_seed, philox_offset, BLOCK: tl.constexpr
):
philox_seed = philox_seed.to(tl.int64)
philox_offset = philox_offset.to(tl.int64)
c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

pid = tl.program_id(0)
i = pid * BLOCK + tl.arange(0, BLOCK)
c0 += i
z = c0 * 0
r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, z, z)

    # Fuse pairs of 32-bit words into 64-bit words for double-precision uniforms.
    u0 = uint_to_uniform_float(paste_u64(r0, r2))
    u1 = uint_to_uniform_float(paste_u64(r1, r3))

y0 = transform_exponential_f64(u0, inv_lambd, eps_minus)
y1 = transform_exponential_f64(u1, inv_lambd, eps_minus)

start = pid.to(tl.uint64) * BLOCK * 2
off0 = start + tl.arange(0, BLOCK)
off1 = off0 + BLOCK

tl.store(out_ptr + off0, y0, mask=off0 < N)
tl.store(out_ptr + off1, y1, mask=off1 < N)


def exponential_(x, lambd: float = 1.0, *, generator=None):
logger.debug("GEMS_ILUVATAR EXPONENTIAL_")

    dtype = x.dtype
    device = x.device
    # Write directly into x when it is contiguous; otherwise fill a scratch
    # buffer and copy the result back at the end.
    inplace = x.is_contiguous()
assert dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)

N = x.numel()

# Handle empty tensor
if N == 0:
return x

inv_lambd = 1.0 / lambd
eps_minus = -0.5 * torch.finfo(dtype).eps

out = x if inplace else torch.empty_like(x)

if dtype is torch.float64:
UNROLL = 2
grid = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),)
increment = triton.cdiv(N, UNROLL)
philox_seed, philox_offset = philox_backend_seed_offset(
increment, generator=generator
)
with torch_device_fn.device(device):
fused_exponential_kernel_f64[grid](
out, N, inv_lambd, eps_minus, philox_seed, philox_offset
)
else:
UNROLL = 4
grid = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),)
increment = triton.cdiv(N, UNROLL)
philox_seed, philox_offset = philox_backend_seed_offset(
increment, generator=generator
)
with torch_device_fn.device(device):
fused_exponential_kernel_f32[grid](
out, N, inv_lambd, eps_minus, philox_seed, philox_offset
)

if not inplace:
x.copy_(out)
return x
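
The wrapper draws Philox uniforms and applies inverse-transform sampling, y = -ln(u) / lambd. A minimal usage sketch, assuming the backend op is imported directly (in practice it is normally reached through flag_gems' dispatch) and that the device is addressed as "cuda":

import torch
from flag_gems.runtime.backend._iluvatar.ops import exponential_

x = torch.empty(1 << 20, device="cuda", dtype=torch.float32)
exponential_(x, lambd=2.0)           # fill x in place with Exp(lambda=2) samples
print(x.mean().item())               # expected to be close to 1 / lambd = 0.5
print((x > 0).all().item())          # samples should be strictly positive
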
130 changes: 130 additions & 0 deletions src/flag_gems/runtime/backend/_iluvatar/ops/pow.py
@@ -0,0 +1,130 @@
import logging

import torch
import triton
import triton.language as tl

from flag_gems.utils import libentry
from flag_gems.utils.shape_utils import volume

logger = logging.getLogger(__name__)


@libentry()
@triton.jit
def pow_scalar_kernel(
output,
input_exponent,
sbase,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
"""
Triton kernel for pow_scalar: output = sbase ** input_exponent
"""
pid = tl.program_id(0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements

# Load exponent values
exponent = tl.load(input_exponent + offsets, mask=mask, other=0.0)

# Compute sbase ** exponent
result = tl.pow(sbase.to(tl.float32), exponent.to(tl.float32))

# Store result
tl.store(output + offsets, result, mask=mask)


@libentry()
@triton.jit
def pow_scalar_inplace_kernel(
input_output,
sbase,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
"""
Triton kernel for in-place pow_scalar_: input_output = sbase ** input_output
"""
pid = tl.program_id(0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements

# Load exponent values
exponent = tl.load(input_output + offsets, mask=mask, other=0.0)

# Compute sbase ** exponent
result = tl.pow(sbase.to(tl.float32), exponent.to(tl.float32))

# Store result in-place
tl.store(input_output + offsets, result, mask=mask)


def pow_scalar(A, exponent):
"""
Computes base^exponent where base is a scalar and exponent is a tensor.

    Launches a Triton kernel optimized for the Iluvatar platform with BLOCK_SIZE=2048.

Args:
A: Scalar base value
exponent: Exponent tensor

Returns:
Output tensor with same shape as exponent
"""
logger.debug("GEMS_ILUVATAR POW_SCALAR")

# Handle empty tensor
if volume(exponent.shape) == 0:
return torch.empty_like(exponent)

output = torch.empty_like(exponent)
n_elements = volume(exponent.shape)

    # Convert the scalar base to a Python float; the kernel computes in float32
    sbase = float(A)

# Grid size
BLOCK_SIZE = 2048
grid = (triton.cdiv(n_elements, BLOCK_SIZE),)

# Launch kernel
pow_scalar_kernel[grid](output, exponent, sbase, n_elements, BLOCK_SIZE=BLOCK_SIZE)

return output


def pow_scalar_(A, exponent):
"""
In-place version of pow_scalar.

Args:
A: Scalar base value
exponent: Exponent tensor (modified in-place)

Returns:
The modified exponent tensor
"""
logger.debug("GEMS_ILUVATAR POW_SCALAR_")

# Handle empty tensor
if volume(exponent.shape) == 0:
return exponent

n_elements = volume(exponent.shape)

    # Convert the scalar base to a Python float; the kernel computes in float32
    sbase = float(A)

# Grid size
BLOCK_SIZE = 2048
grid = (triton.cdiv(n_elements, BLOCK_SIZE),)

# Launch kernel
pow_scalar_inplace_kernel[grid](exponent, sbase, n_elements, BLOCK_SIZE=BLOCK_SIZE)

return exponent
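
A minimal usage sketch for the scalar-base pow ops (illustrative only; assumes direct import of the backend module and a "cuda" device):

import torch
from flag_gems.runtime.backend._iluvatar.ops import pow_scalar, pow_scalar_

exp = torch.tensor([0.0, 1.0, 2.0, 3.0], device="cuda")
out = pow_scalar(2.0, exp)           # 2 ** exp -> [1, 2, 4, 8]
print(out)

pow_scalar_(2.0, exp)                # in place: exp now holds 2 ** exp
print(exp)
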