# NOTE(review): SOURCE is a whitespace-mangled `git diff`; below is the
# reconstructed content of this span with diff markers removed. File
# boundaries from the patch are kept as separator comments.

# === src/flag_gems/runtime/backend/_iluvatar/ops/__init__.py ===
from .add import add, add_
from .div import div_mode, div_mode_
from .exponential_ import exponential_
from .mul import mul
from .pow import pow_scalar, pow_scalar_
from .sub import sub, sub_
from .true_divide import true_divide, true_divide_, true_divide_out
from .view_as_complex import view_as_complex

__all__ = [
    "add",
    "add_",
    "div_mode",
    "div_mode_",
    "exponential_",
    "mul",
    "pow_scalar",
    "pow_scalar_",
    "sub",
    "sub_",
    "true_divide",
    "true_divide_out",
    "true_divide_",
    "view_as_complex",
]

# === src/flag_gems/runtime/backend/_iluvatar/ops/add.py ===
import logging

import torch
import triton

from flag_gems.utils import pointwise_dynamic

logger = logging.getLogger(__name__)


@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def add_func(x, y, alpha):
    # tensor + tensor, scaled by scalar alpha
    return x + y * alpha


@pointwise_dynamic(
    is_tensor=[True, False, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def add_func_tensor_scalar(x, y, alpha):
    # tensor + scalar, scaled by scalar alpha
    return x + y * alpha


@pointwise_dynamic(
    is_tensor=[False, True, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def add_func_scalar_tensor(x, y, alpha):
    # scalar + tensor, scaled by scalar alpha
    return x + y * alpha


def add(A, B, *, alpha=1):
    """Compute ``A + B * alpha`` for any tensor/scalar combination.

    Dispatches to the matching pointwise kernel depending on which of the
    two operands are tensors. When both are tensors and live on different
    devices, ``B`` is moved to ``A``'s device first.

    Args:
        A: Tensor or Python scalar (left operand).
        B: Tensor or Python scalar (right operand, scaled by ``alpha``).
        alpha: Scalar multiplier applied to ``B``.

    Returns:
        A new tensor holding the result (a 0-dim tensor when both
        operands are scalars).
    """
    logger.debug("GEMS ILUVATAR ADD")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        if B.device != A.device:
            B = B.to(A.device)
        return add_func(A, B, alpha)
    if a_is_tensor:
        return add_func_tensor_scalar(A, B, alpha)
    if b_is_tensor:
        return add_func_scalar_tensor(A, B, alpha)
    # both plain scalars: compute on the host and wrap the result
    return torch.tensor(A + B * alpha)


def add_(A, B, *, alpha=1):
    """In-place ``A += B * alpha``; writes the result into ``A``.

    Args:
        A: Tensor to be updated in place (must be a tensor).
        B: Tensor or Python scalar to add, scaled by ``alpha``.
        alpha: Scalar multiplier applied to ``B``.

    Returns:
        ``A`` (mutated).

    Raises:
        ValueError: if ``A`` is not a tensor — a scalar left-hand side
            cannot be mutated in place.
    """
    logger.debug("GEMS ILUVATAR ADD_")
    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
        if B.device != A.device:
            B = B.to(A.device)
        return add_func(A, B, alpha, out0=A)
    if isinstance(A, torch.Tensor):
        return add_func_tensor_scalar(A, B, alpha, out0=A)
    # Note: scalar_tensor case not supported for in-place
    raise ValueError("Unreachable.")

# === src/flag_gems/runtime/backend/_iluvatar/ops/exponential_.py ===
import logging

import torch
import triton
import triton.language as tl

from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry, libtuner
from flag_gems.utils.random_utils import (
    philox_backend_seed_offset,
    uint_to_uniform_float,
)

logger = logging.getLogger(__name__)


@triton.jit
def safe_fast_log_f32(x):
    """Fast polynomial log(x) for f32 inputs clamped into (min-normal, 1).

    Decomposes the float into exponent and mantissa via bitcast, then
    evaluates a short Taylor-style polynomial in (mantissa - 1) and adds
    exponent * ln(2). The clamp keeps the bit tricks out of subnormal /
    >= 1.0 territory where they would be wrong.
    """
    # clamp to the smallest normal f32 and just below 1.0
    lo = (x * 0.0 + 1.17549435e-38).to(tl.float32)
    hi = x * 0.0 + 0.99999994
    x = tl.minimum(tl.maximum(x, lo), hi)
    raw = x.to(tl.int32, bitcast=True)
    exponent = (raw >> 23) - 127
    mantissa = (raw & 0x7FFFFF).to(tl.float32) * (1.0 / 8388608.0) + 1.0
    m1 = mantissa - 1.0
    return (
        m1 * (1.0 + m1 * (-0.5 + m1 * (0.3333333333 - m1 * 0.25)))
        + exponent.to(tl.float32) * 0.6931471805599453
    )


@triton.jit
def safe_fast_log_f64(x):
    """Fast polynomial log(x) for f64 inputs clamped into (min-normal, 1).

    Same bit-decomposition scheme as the f32 variant, with 64-bit
    exponent/mantissa layout (11-bit exponent, 52-bit mantissa).
    """
    lo = x * 0.0 + 2.2250738585072014e-308
    hi = x * 0.0 + (1.0 - 2.220446049250313e-16)
    x = tl.minimum(tl.maximum(x, lo), hi)
    raw = x.to(tl.int64, bitcast=True)
    exponent = (raw >> 52) - 1023
    mantissa = (raw & 0x000FFFFFFFFFFFFF).to(tl.float64) * (
        1.0 / 4503599627370496.0
    ) + 1.0
    m1 = mantissa - 1.0
    return (
        m1 * (1.0 + m1 * (-0.5 + m1 * (0.3333333333333333 - m1 * 0.25)))
        + exponent.to(tl.float64) * 0.6931471805599453
    )
# === src/flag_gems/runtime/backend/_iluvatar/ops/exponential_.py (continued) ===


@triton.jit
def paste_u64(hi: tl.uint32, lo: tl.uint32):
    """Concatenate two u32 philox outputs into one u64 random word."""
    return (hi.to(tl.uint64) << 32) | lo.to(tl.uint64)


@triton.jit
def transform_exponential_f32_precise(u, inv_lambd, eps_minus):
    """Map a uniform sample u to Exp(lambd) via -log(u)/lambd (precise log)."""
    # guard u >= 1: substitute eps_minus so log never sees exactly 1.0
    log = tl.where(u >= 1.0 + eps_minus, eps_minus, tl.math.log(u))
    return -inv_lambd * log


@triton.jit
def transform_exponential_f32_fast(u, inv_lambd, eps_minus):
    """Same mapping as the precise variant but using the fast polynomial log."""
    log = tl.where(u >= 1.0 + eps_minus, eps_minus, safe_fast_log_f32(u))
    return -inv_lambd * log


# Iluvatar uses the precise version for numerical stability
transform_exponential_f32 = transform_exponential_f32_precise


@triton.jit
def transform_exponential_f64(u, inv_lambd, eps_minus):
    """f64 uniform -> exponential transform using the fast polynomial log."""
    log = tl.where(u >= 1.0 + eps_minus, eps_minus, safe_fast_log_f64(u))
    return -inv_lambd * log


@libentry()
@libtuner(
    configs=[
        triton.Config({"BLOCK": 64}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK": 128}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK": 512}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK": 1024}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK": 2048}, num_warps=8, num_stages=2),
    ],
    key=["N"],
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel_f32(
    out_ptr, N, inv_lambd, eps_minus, philox_seed, philox_offset, BLOCK: tl.constexpr
):
    """Fill out_ptr[0:N] with Exp(lambd) samples; 4 f32 values per lane.

    Each philox call yields four u32 words, so one program instance covers
    BLOCK * 4 contiguous outputs.
    """
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # split the 64-bit offset into the two 32-bit philox counter words
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    pid = tl.program_id(0)
    lane = pid * BLOCK + tl.arange(0, BLOCK)
    c0 += lane
    zero = c0 * 0
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, zero, zero)

    y0 = transform_exponential_f32(uint_to_uniform_float(r0), inv_lambd, eps_minus)
    y1 = transform_exponential_f32(uint_to_uniform_float(r1), inv_lambd, eps_minus)
    y2 = transform_exponential_f32(uint_to_uniform_float(r2), inv_lambd, eps_minus)
    y3 = transform_exponential_f32(uint_to_uniform_float(r3), inv_lambd, eps_minus)

    base = pid.to(tl.uint64) * BLOCK * 4
    off0 = base + tl.arange(0, BLOCK)
    off1 = off0 + BLOCK
    off2 = off1 + BLOCK
    off3 = off2 + BLOCK

    tl.store(out_ptr + off0, y0, mask=off0 < N)
    tl.store(out_ptr + off1, y1, mask=off1 < N)
    tl.store(out_ptr + off2, y2, mask=off2 < N)
    tl.store(out_ptr + off3, y3, mask=off3 < N)


@libentry()
@libtuner(
    configs=[
        triton.Config({"BLOCK": 64}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK": 128}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK": 512}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK": 1024}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK": 2048}, num_warps=8, num_stages=2),
    ],
    key=["N"],
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel_f64(
    out_ptr, N, inv_lambd, eps_minus, philox_seed, philox_offset, BLOCK: tl.constexpr
):
    """Fill out_ptr[0:N] with Exp(lambd) samples; 2 f64 values per lane.

    Pairs of u32 philox words are pasted into u64 randoms, so one program
    instance covers BLOCK * 2 contiguous outputs.
    """
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    pid = tl.program_id(0)
    lane = pid * BLOCK + tl.arange(0, BLOCK)
    c0 += lane
    zero = c0 * 0
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, zero, zero)

    u0 = uint_to_uniform_float(paste_u64(r0, r2))
    u1 = uint_to_uniform_float(paste_u64(r1, r3))

    y0 = transform_exponential_f64(u0, inv_lambd, eps_minus)
    y1 = transform_exponential_f64(u1, inv_lambd, eps_minus)

    base = pid.to(tl.uint64) * BLOCK * 2
    off0 = base + tl.arange(0, BLOCK)
    off1 = off0 + BLOCK

    tl.store(out_ptr + off0, y0, mask=off0 < N)
    tl.store(out_ptr + off1, y1, mask=off1 < N)


def exponential_(x, lambd: float = 1.0, *, generator=None):
    """In-place fill of ``x`` with samples from Exp(lambd).

    Works on the flat element buffer: contiguous tensors are filled
    directly; non-contiguous tensors are filled via a scratch tensor and
    copied back element-wise.

    Args:
        x: Tensor to fill (float16/bfloat16/float32/float64).
        lambd: Rate parameter of the exponential distribution.
        generator: Optional RNG used to derive the philox seed/offset.

    Returns:
        ``x`` (mutated).
    """
    logger.debug("GEMS_ILUVATAR EXPONENTIAL_")

    dtype = x.dtype
    device = x.device
    inplace = x.is_contiguous()
    assert dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)

    N = x.numel()
    # nothing to fill for an empty tensor
    if N == 0:
        return x

    inv_lambd = 1.0 / lambd
    # negative half-epsilon: sentinel substituted for log(u) when u >= 1
    eps_minus = -0.5 * torch.finfo(dtype).eps

    out = x if inplace else torch.empty_like(x)

    # both kernels share the same launch recipe; only the kernel and the
    # per-lane unroll factor differ between the f64 and f32 paths
    if dtype is torch.float64:
        kernel, UNROLL = fused_exponential_kernel_f64, 2
    else:
        kernel, UNROLL = fused_exponential_kernel_f32, 4

    grid = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),)
    increment = triton.cdiv(N, UNROLL)
    philox_seed, philox_offset = philox_backend_seed_offset(
        increment, generator=generator
    )
    with torch_device_fn.device(device):
        kernel[grid](out, N, inv_lambd, eps_minus, philox_seed, philox_offset)

    if not inplace:
        x.copy_(out)
    return x

# === src/flag_gems/runtime/backend/_iluvatar/ops/pow.py ===
import logging

import triton
import triton.language as tl

from flag_gems.utils import pointwise_dynamic, tl_extra_shim

_pow = tl_extra_shim.pow
logger = logging.getLogger(__name__)


@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "BOOL_TO_LONG")])
@triton.jit
def pow_func_scalar_tensor(x, exponent):
    # compute in f32 via the extra-shim pow for hardware compatibility
    return _pow(x.to(tl.float32), exponent.to(tl.float32))


def pow_scalar(A, exponent):
    """
    Computes base^exponent where base is a scalar and exponent is a tensor.

    Uses FlagGems standard pointwise_dynamic for hardware compatibility.

    Args:
        A: Scalar base value
        exponent: Exponent tensor

    Returns:
        Output tensor with same shape as exponent
    """
    logger.debug("GEMS_ILUVATAR POW_SCALAR")
    return pow_func_scalar_tensor(A, exponent)
# === src/flag_gems/runtime/backend/_iluvatar/ops/pow.py (continued) ===


def pow_scalar_(A, exponent):
    """
    In-place version of pow_scalar: writes ``A ** exponent`` into ``exponent``.

    Args:
        A: Scalar base value
        exponent: Exponent tensor (modified in-place)

    Returns:
        The modified exponent tensor
    """
    logger.debug("GEMS_ILUVATAR POW_SCALAR_")
    return pow_func_scalar_tensor(A, exponent, out0=exponent)


# === src/flag_gems/runtime/backend/_iluvatar/ops/sub.py ===
import logging

import torch
import triton
import triton.language as tl

from flag_gems.utils import libentry

logger = logging.getLogger(__name__)


@libentry()
@triton.jit
def sub_kernel(
    x_ptr,
    y_ptr,
    output_ptr,
    alpha,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """output = x - y * alpha (tensor - tensor)."""
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    result = x - y * alpha
    tl.store(output_ptr + offsets, result, mask=mask)


@libentry()
@triton.jit
def sub_scalar_kernel(
    x_ptr,
    y_scalar,
    output_ptr,
    alpha,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """output = x - y_scalar * alpha (tensor - scalar)."""
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    result = x - y_scalar * alpha
    tl.store(output_ptr + offsets, result, mask=mask)


@libentry()
@triton.jit
def sub_scalar_tensor_kernel(
    x_scalar,
    y_ptr,
    output_ptr,
    alpha,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """output = x_scalar - y * alpha (scalar - tensor)."""
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    y = tl.load(y_ptr + offsets, mask=mask)
    result = x_scalar - y * alpha
    tl.store(output_ptr + offsets, result, mask=mask)


@libentry()
@triton.jit
def sub__kernel(
    x_ptr,
    y_ptr,
    alpha,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """In-place subtraction: x = x - y * alpha."""
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    result = x - y * alpha
    tl.store(x_ptr + offsets, result, mask=mask)


@libentry()
@triton.jit
def sub__scalar_kernel(
    x_ptr,
    y_scalar,
    alpha,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """In-place subtraction with scalar: x = x - y_scalar * alpha."""
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    result = x - y_scalar * alpha
    tl.store(x_ptr + offsets, result, mask=mask)


def sub(A, B, *, alpha=1):
    """Subtraction operator: output = A - B * alpha

    Supports:
    - tensor - tensor (with broadcasting)
    - tensor - scalar
    - scalar - tensor
    - scalar - scalar
    """
    logger.debug("GEMS_ILUVATAR SUB")

    # Normalize alpha to a Python float for the kernels.
    if isinstance(alpha, torch.Tensor):
        alpha = alpha.item()
    alpha = float(alpha)

    BLOCK_SIZE = 2048

    # Both are tensors
    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
        A, B = torch.broadcast_tensors(A, B)
        A = A.contiguous()
        B = B.contiguous()
        n_elements = A.numel()

        if n_elements == 0:
            return torch.empty_like(A)

        # 0-dim tensors: compute on the host instead of launching a kernel.
        # NOTE(review): the result dtype is forced to A.dtype here, which may
        # differ from torch.sub's type promotion — confirm against callers.
        if A.dim() == 0:
            result = A.item() - B.item() * alpha
            return torch.tensor(result, dtype=A.dtype, device=A.device)

        output = torch.empty_like(A)
        grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
        sub_kernel[grid](A, B, output, alpha, n_elements, BLOCK_SIZE=BLOCK_SIZE)
        return output

    # A is tensor, B is scalar
    elif isinstance(A, torch.Tensor):
        A = A.contiguous()
        n_elements = A.numel()

        if n_elements == 0:
            return torch.empty_like(A)

        if A.dim() == 0:
            result = A.item() - float(B) * alpha
            return torch.tensor(result, dtype=A.dtype, device=A.device)

        output = torch.empty_like(A)
        grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
        sub_scalar_kernel[grid](
            A, float(B), output, alpha, n_elements, BLOCK_SIZE=BLOCK_SIZE
        )
        return output

    # B is tensor, A is scalar
    elif isinstance(B, torch.Tensor):
        B = B.contiguous()
        n_elements = B.numel()

        if n_elements == 0:
            return torch.empty_like(B)

        if B.dim() == 0:
            result = float(A) - B.item() * alpha
            return torch.tensor(result, dtype=B.dtype, device=B.device)

        output = torch.empty_like(B)
        grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
        sub_scalar_tensor_kernel[grid](
            float(A), B, output, alpha, n_elements, BLOCK_SIZE=BLOCK_SIZE
        )
        return output

    # Both scalars
    else:
        return torch.tensor(A - B * alpha)


def sub_(A, B, *, alpha=1):
    """In-place subtraction operator: A = A - B * alpha. Returns A."""
    logger.debug("GEMS_ILUVATAR SUB_")

    if isinstance(alpha, torch.Tensor):
        alpha = alpha.item()
    alpha = float(alpha)

    n_elements = A.numel()
    if n_elements == 0:
        return A

    # 0-dim tensor: compute on the host and copy the value back into A.
    if A.dim() == 0:
        if isinstance(B, torch.Tensor):
            result = A.item() - B.item() * alpha
        else:
            result = A.item() - float(B) * alpha
        A.copy_(torch.tensor(result, dtype=A.dtype, device=A.device))
        return A

    BLOCK_SIZE = 2048
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)

    # BUGFIX: `.contiguous()` on a non-contiguous tensor returns a *copy*.
    # The previous code mutated (and returned) that copy, so the caller's
    # tensor was never updated — violating the in-place contract. Run the
    # kernel on a contiguous work buffer and copy the result back into A
    # whenever a copy was made.
    work = A if A.is_contiguous() else A.contiguous()

    if isinstance(B, torch.Tensor):
        # Broadcast B up to A's shape, then materialize it for flat indexing.
        if A.shape != B.shape:
            B = B.expand_as(A)
        B = B.contiguous()
        sub__kernel[grid](work, B, alpha, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    else:
        sub__scalar_kernel[grid](
            work, float(B), alpha, n_elements, BLOCK_SIZE=BLOCK_SIZE
        )

    if work is not A:
        A.copy_(work)
    return A


# === src/flag_gems/runtime/backend/_iluvatar/ops/view_as_complex.py ===
import logging

import torch

# Consistency fix: use a module-level logger like every sibling op module,
# instead of logging through the root logger.
logger = logging.getLogger(__name__)


def view_as_complex(input: torch.Tensor) -> torch.Tensor:
    """
    Convert a real tensor with last dimension 2 to a complex tensor.

    Args:
        input: Input tensor with shape (..., 2) and dtype float32 or float64

    Returns:
        Complex tensor with shape (...) and dtype complex64 or complex128
    """
    logger.debug("GEMS_ILUVATAR VIEW_AS_COMPLEX")

    # NOTE(review): `.contiguous()` copies when the input is non-contiguous,
    # so the result does not always alias `input` as torch.view_as_complex
    # normally guarantees — confirm no caller relies on writes through the
    # complex view reaching `input`.
    return torch.view_as_complex(input.contiguous())