diff --git a/src/flag_gems/runtime/backend/_iluvatar/ops/__init__.py b/src/flag_gems/runtime/backend/_iluvatar/ops/__init__.py
index 7ee6d04793..f1e764731f 100644
--- a/src/flag_gems/runtime/backend/_iluvatar/ops/__init__.py
+++ b/src/flag_gems/runtime/backend/_iluvatar/ops/__init__.py
@@ -1,6 +1,15 @@
 from .div import div_mode, div_mode_
+from .exponential_ import exponential_
+from .pow import pow_scalar, pow_scalar_
+from .true_divide import true_divide, true_divide_, true_divide_out
 
 __all__ = [
     "div_mode",
     "div_mode_",
+    "exponential_",
+    "pow_scalar",
+    "pow_scalar_",
+    "true_divide",
+    "true_divide_out",
+    "true_divide_",
 ]
diff --git a/src/flag_gems/runtime/backend/_iluvatar/ops/exponential_.py b/src/flag_gems/runtime/backend/_iluvatar/ops/exponential_.py
new file mode 100644
index 0000000000..e99a953d4c
--- /dev/null
+++ b/src/flag_gems/runtime/backend/_iluvatar/ops/exponential_.py
@@ -0,0 +1,205 @@
+import logging
+
+import torch
+import triton
+import triton.language as tl
+
+from flag_gems.runtime import torch_device_fn
+from flag_gems.utils import libentry, libtuner
+from flag_gems.utils.random_utils import (
+    philox_backend_seed_offset,
+    uint_to_uniform_float,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@triton.jit
+def safe_fast_log_f32(x):
+    min_normal = (x * 0.0 + 1.17549435e-38).to(tl.float32)
+    max_u = x * 0.0 + 0.99999994
+    x = tl.minimum(tl.maximum(x, min_normal), max_u)
+    bits = x.to(tl.int32, bitcast=True)
+    exponent = (bits >> 23) - 127
+    mantissa = (bits & 0x7FFFFF).to(tl.float32) * (1.0 / 8388608.0) + 1.0
+    m1 = mantissa - 1.0
+    return (
+        m1 * (1.0 + m1 * (-0.5 + m1 * (0.3333333333 - m1 * 0.25)))
+        + exponent.to(tl.float32) * 0.6931471805599453
+    )
+
+
+@triton.jit
+def safe_fast_log_f64(x):
+    min_normal = x * 0.0 + 2.2250738585072014e-308
+    max_u = x * 0.0 + (1.0 - 2.220446049250313e-16)
+    x = tl.minimum(tl.maximum(x, min_normal), max_u)
+    bits = x.to(tl.int64, bitcast=True)
+    exponent = (bits >> 52) - 1023
+    mantissa = (bits & 0x000FFFFFFFFFFFFF).to(tl.float64) * (
+        1.0 / 4503599627370496.0
+    ) + 1.0
+    m1 = mantissa - 1.0
+    return (
+        m1 * (1.0 + m1 * (-0.5 + m1 * (0.3333333333333333 - m1 * 0.25)))
+        + exponent.to(tl.float64) * 0.6931471805599453
+    )
+
+
+@triton.jit
+def paste_u64(hi: tl.uint32, lo: tl.uint32):
+    return (hi.to(tl.uint64) << 32) | lo.to(tl.uint64)
+
+
+@triton.jit
+def transform_exponential_f32_precise(u, inv_lambd, eps_minus):
+    log = tl.where(u >= 1.0 + eps_minus, eps_minus, tl.math.log(u))
+    return -inv_lambd * log
+
+
+@triton.jit
+def transform_exponential_f32_fast(u, inv_lambd, eps_minus):
+    log = tl.where(u >= 1.0 + eps_minus, eps_minus, safe_fast_log_f32(u))
+    return -inv_lambd * log
+
+
+# Iluvatar uses the precise version for numerical stability
+transform_exponential_f32 = transform_exponential_f32_precise
+
+
+@triton.jit
+def transform_exponential_f64(u, inv_lambd, eps_minus):
+    log = tl.where(u >= 1.0 + eps_minus, eps_minus, safe_fast_log_f64(u))
+    return -inv_lambd * log
+
+
+@libentry()
+@libtuner(
+    configs=[
+        triton.Config({"BLOCK": 64}, num_warps=2, num_stages=2),
+        triton.Config({"BLOCK": 128}, num_warps=2, num_stages=2),
+        triton.Config({"BLOCK": 256}, num_warps=4, num_stages=2),
+        triton.Config({"BLOCK": 512}, num_warps=4, num_stages=3),
+        triton.Config({"BLOCK": 1024}, num_warps=4, num_stages=2),
+        triton.Config({"BLOCK": 2048}, num_warps=8, num_stages=2),
+    ],
+    key=["N"],
+)
+@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
+def fused_exponential_kernel_f32(
+    out_ptr, N, inv_lambd, eps_minus, philox_seed, philox_offset, BLOCK: tl.constexpr
+):
+    philox_seed = philox_seed.to(tl.int64)
+    philox_offset = philox_offset.to(tl.int64)
+    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
+    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
+
+    pid = tl.program_id(0)
+    i = pid * BLOCK + tl.arange(0, BLOCK)
+    c0 += i
+    z = c0 * 0
+    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, z, z)
+
+    y0 = transform_exponential_f32(uint_to_uniform_float(r0), inv_lambd, eps_minus)
+    y1 = transform_exponential_f32(uint_to_uniform_float(r1), inv_lambd, eps_minus)
+    y2 = transform_exponential_f32(uint_to_uniform_float(r2), inv_lambd, eps_minus)
+    y3 = transform_exponential_f32(uint_to_uniform_float(r3), inv_lambd, eps_minus)
+
+    start = pid.to(tl.uint64) * BLOCK * 4
+    off0 = start + tl.arange(0, BLOCK)
+    off1 = off0 + BLOCK
+    off2 = off1 + BLOCK
+    off3 = off2 + BLOCK
+
+    tl.store(out_ptr + off0, y0, mask=off0 < N)
+    tl.store(out_ptr + off1, y1, mask=off1 < N)
+    tl.store(out_ptr + off2, y2, mask=off2 < N)
+    tl.store(out_ptr + off3, y3, mask=off3 < N)
+
+
+@libentry()
+@libtuner(
+    configs=[
+        triton.Config({"BLOCK": 64}, num_warps=2, num_stages=2),
+        triton.Config({"BLOCK": 128}, num_warps=2, num_stages=2),
+        triton.Config({"BLOCK": 256}, num_warps=4, num_stages=2),
+        triton.Config({"BLOCK": 512}, num_warps=4, num_stages=3),
+        triton.Config({"BLOCK": 1024}, num_warps=4, num_stages=2),
+        triton.Config({"BLOCK": 2048}, num_warps=8, num_stages=2),
+    ],
+    key=["N"],
+)
+@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
+def fused_exponential_kernel_f64(
+    out_ptr, N, inv_lambd, eps_minus, philox_seed, philox_offset, BLOCK: tl.constexpr
+):
+    philox_seed = philox_seed.to(tl.int64)
+    philox_offset = philox_offset.to(tl.int64)
+    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
+    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
+
+    pid = tl.program_id(0)
+    i = pid * BLOCK + tl.arange(0, BLOCK)
+    c0 += i
+    z = c0 * 0
+    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, z, z)
+
+    u0 = uint_to_uniform_float(paste_u64(r0, r2))
+    u1 = uint_to_uniform_float(paste_u64(r1, r3))
+
+    y0 = transform_exponential_f64(u0, inv_lambd, eps_minus)
+    y1 = transform_exponential_f64(u1, inv_lambd, eps_minus)
+
+    start = pid.to(tl.uint64) * BLOCK * 2
+    off0 = start + tl.arange(0, BLOCK)
+    off1 = off0 + BLOCK
+
+    tl.store(out_ptr + off0, y0, mask=off0 < N)
+    tl.store(out_ptr + off1, y1, mask=off1 < N)
+
+
+def exponential_(x, lambd: float = 1.0, *, generator=None):
+    logger.debug("GEMS_ILUVATAR EXPONENTIAL_")
+
+    dtype = x.dtype
+    device = x.device
+    inplace = x.is_contiguous()
+    assert dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
+
+    N = x.numel()
+
+    # Handle empty tensor
+    if N == 0:
+        return x
+
+    inv_lambd = 1.0 / lambd
+    eps_minus = -0.5 * torch.finfo(dtype).eps
+
+    out = x if inplace else torch.empty_like(x)
+
+    if dtype is torch.float64:
+        UNROLL = 2
+        grid = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),)
+        increment = triton.cdiv(N, UNROLL)
+        philox_seed, philox_offset = philox_backend_seed_offset(
+            increment, generator=generator
+        )
+        with torch_device_fn.device(device):
+            fused_exponential_kernel_f64[grid](
+                out, N, inv_lambd, eps_minus, philox_seed, philox_offset
+            )
+    else:
+        UNROLL = 4
+        grid = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),)
+        increment = triton.cdiv(N, UNROLL)
+        philox_seed, philox_offset = philox_backend_seed_offset(
+            increment, generator=generator
+        )
+        with torch_device_fn.device(device):
+            fused_exponential_kernel_f32[grid](
+                out, N, inv_lambd, eps_minus, philox_seed, philox_offset
+            )
+
+    if not inplace:
+        x.copy_(out)
+    return x
diff --git a/src/flag_gems/runtime/backend/_iluvatar/ops/pow.py b/src/flag_gems/runtime/backend/_iluvatar/ops/pow.py
new file mode 100644
index 0000000000..4ba268073c
--- /dev/null
+++ b/src/flag_gems/runtime/backend/_iluvatar/ops/pow.py
@@ -0,0 +1,130 @@
+import logging
+
+import torch
+import triton
+import triton.language as tl
+
+from flag_gems.utils import libentry
+from flag_gems.utils.shape_utils import volume
+
+logger = logging.getLogger(__name__)
+
+
+@libentry()
+@triton.jit
+def pow_scalar_kernel(
+    output,
+    input_exponent,
+    sbase,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    Triton kernel for pow_scalar: output = sbase ** input_exponent
+    """
+    pid = tl.program_id(0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+
+    # Load exponent values
+    exponent = tl.load(input_exponent + offsets, mask=mask, other=0.0)
+
+    # Compute sbase ** exponent
+    result = tl.pow(sbase.to(tl.float32), exponent.to(tl.float32))
+
+    # Store result
+    tl.store(output + offsets, result, mask=mask)
+
+
+@libentry()
+@triton.jit
+def pow_scalar_inplace_kernel(
+    input_output,
+    sbase,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    Triton kernel for in-place pow_scalar_: input_output = sbase ** input_output
+    """
+    pid = tl.program_id(0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+
+    # Load exponent values
+    exponent = tl.load(input_output + offsets, mask=mask, other=0.0)
+
+    # Compute sbase ** exponent
+    result = tl.pow(sbase.to(tl.float32), exponent.to(tl.float32))
+
+    # Store result in-place
+    tl.store(input_output + offsets, result, mask=mask)
+
+
+def pow_scalar(A, exponent):
+    """
+    Computes base^exponent where base is a scalar and exponent is a tensor.
+
+    Optimized Triton kernel for Iluvatar platform with BLOCK_SIZE=2048.
+
+    Args:
+        A: Scalar base value
+        exponent: Exponent tensor
+
+    Returns:
+        Output tensor with same shape as exponent
+    """
+    logger.debug("GEMS_ILUVATAR POW_SCALAR")
+
+    # Handle empty tensor
+    if volume(exponent.shape) == 0:
+        return torch.empty_like(exponent)
+
+    output = torch.empty_like(exponent)
+    n_elements = volume(exponent.shape)
+
+    # Convert scalar base to float32 for computation
+    sbase = float(A)
+
+    # Grid size
+    BLOCK_SIZE = 2048
+    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
+
+    # Launch kernel
+    pow_scalar_kernel[grid](output, exponent, sbase, n_elements, BLOCK_SIZE=BLOCK_SIZE)
+
+    return output
+
+
+def pow_scalar_(A, exponent):
+    """
+    In-place version of pow_scalar.
+
+    Args:
+        A: Scalar base value
+        exponent: Exponent tensor (modified in-place)
+
+    Returns:
+        The modified exponent tensor
+    """
+    logger.debug("GEMS_ILUVATAR POW_SCALAR_")
+
+    # Handle empty tensor
+    if volume(exponent.shape) == 0:
+        return exponent
+
+    n_elements = volume(exponent.shape)
+
+    # Convert scalar base to float32 for computation
+    sbase = float(A)
+
+    # Grid size
+    BLOCK_SIZE = 2048
+    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
+
+    # Launch kernel
+    pow_scalar_inplace_kernel[grid](exponent, sbase, n_elements, BLOCK_SIZE=BLOCK_SIZE)
+
+    return exponent