diff --git a/src/flag_gems/runtime/backend/_iluvatar/ops/__init__.py b/src/flag_gems/runtime/backend/_iluvatar/ops/__init__.py
index 7ee6d04793..5d1c90d0d6 100644
--- a/src/flag_gems/runtime/backend/_iluvatar/ops/__init__.py
+++ b/src/flag_gems/runtime/backend/_iluvatar/ops/__init__.py
@@ -1,6 +1,27 @@
+from .add import add, add_
 from .div import div_mode, div_mode_
+from .exponential_ import exponential_
+from .mul import mul
+from .ones import ones
+from .pow import pow_scalar, pow_scalar_
+from .repeat import repeat
+from .sub import sub, sub_
+from .true_divide import true_divide, true_divide_, true_divide_out
 
 __all__ = [
+    "add",
+    "add_",
     "div_mode",
     "div_mode_",
+    "exponential_",
+    "mul",
+    "ones",
+    "pow_scalar",
+    "pow_scalar_",
+    "repeat",
+    "sub",
+    "sub_",
+    "true_divide",
+    "true_divide_out",
+    "true_divide_",
 ]
diff --git a/src/flag_gems/runtime/backend/_iluvatar/ops/add.py b/src/flag_gems/runtime/backend/_iluvatar/ops/add.py
new file mode 100644
index 0000000000..14cf46de92
--- /dev/null
+++ b/src/flag_gems/runtime/backend/_iluvatar/ops/add.py
@@ -0,0 +1,57 @@
+import logging
+
+import torch
+import triton
+
+from flag_gems.utils import pointwise_dynamic
+
+logger = logging.getLogger(__name__)
+
+
+@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[(0, 1, "DEFAULT")])
+@triton.jit
+def add_func(x, y, alpha):
+    return x + y * alpha
+
+
+@pointwise_dynamic(
+    is_tensor=[True, False, False], promotion_methods=[(0, 1, "DEFAULT")]
+)
+@triton.jit
+def add_func_tensor_scalar(x, y, alpha):
+    return x + y * alpha
+
+
+@pointwise_dynamic(
+    is_tensor=[False, True, False], promotion_methods=[(0, 1, "DEFAULT")]
+)
+@triton.jit
+def add_func_scalar_tensor(x, y, alpha):
+    return x + y * alpha
+
+
+def add(A, B, *, alpha=1):
+    logger.debug("GEMS ILUVATAR ADD")
+    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
+        if B.device != A.device:
+            B = B.to(A.device)
+        return add_func(A, B, alpha)
+    elif isinstance(A, torch.Tensor):
+        return add_func_tensor_scalar(A, B, alpha)
+    elif isinstance(B, torch.Tensor):
+        return add_func_scalar_tensor(A, B, alpha)
+    else:
+        return torch.tensor(A + B * alpha)
+
+
+def add_(A, B, *, alpha=1):
+    logger.debug("GEMS ILUVATAR ADD_")
+    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
+        if B.device != A.device:
+            B = B.to(A.device)
+        return add_func(A, B, alpha, out0=A)
+    elif isinstance(A, torch.Tensor):
+        return add_func_tensor_scalar(A, B, alpha, out0=A)
+    # Note: scalar_tensor case not supported for in-place
+    else:
+        raise ValueError("Unreachable.")
diff --git a/src/flag_gems/runtime/backend/_iluvatar/ops/exponential_.py b/src/flag_gems/runtime/backend/_iluvatar/ops/exponential_.py
new file mode 100644
index 0000000000..e99a953d4c
--- /dev/null
+++ b/src/flag_gems/runtime/backend/_iluvatar/ops/exponential_.py
@@ -0,0 +1,205 @@
+import logging
+
+import torch
+import triton
+import triton.language as tl
+
+from flag_gems.runtime import torch_device_fn
+from flag_gems.utils import libentry, libtuner
+from flag_gems.utils.random_utils import (
+    philox_backend_seed_offset,
+    uint_to_uniform_float,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@triton.jit
+def safe_fast_log_f32(x):
+    min_normal = (x * 0.0 + 1.17549435e-38).to(tl.float32)
+    max_u = x * 0.0 + 0.99999994
+    x = tl.minimum(tl.maximum(x, min_normal), max_u)
+    bits = x.to(tl.int32, bitcast=True)
+    exponent = (bits >> 23) - 127
+    mantissa = (bits & 0x7FFFFF).to(tl.float32) * (1.0 / 8388608.0) + 1.0
+    m1 = mantissa - 1.0
+    return (
+        m1 * (1.0 + m1 * (-0.5 + m1 * (0.3333333333 - m1 * 0.25)))
+        + exponent.to(tl.float32) * 0.6931471805599453
+    )
+
+
+@triton.jit
+def safe_fast_log_f64(x):
+    min_normal = x * 0.0 + 2.2250738585072014e-308
+    max_u = x * 0.0 + (1.0 - 2.220446049250313e-16)
+    x = tl.minimum(tl.maximum(x, min_normal), max_u)
+    bits = x.to(tl.int64, bitcast=True)
+    exponent = (bits >> 52) - 1023
+    mantissa = (bits & 0x000FFFFFFFFFFFFF).to(tl.float64) * (
+        1.0 / 4503599627370496.0
+    ) + 1.0
+    m1 = mantissa - 1.0
+    return (
+        m1 * (1.0 + m1 * (-0.5 + m1 * (0.3333333333333333 - m1 * 0.25)))
+        + exponent.to(tl.float64) * 0.6931471805599453
+    )
+
+
+@triton.jit
+def paste_u64(hi: tl.uint32, lo: tl.uint32):
+    return (hi.to(tl.uint64) << 32) | lo.to(tl.uint64)
+
+
+@triton.jit
+def transform_exponential_f32_precise(u, inv_lambd, eps_minus):
+    log = tl.where(u >= 1.0 + eps_minus, eps_minus, tl.math.log(u))
+    return -inv_lambd * log
+
+
+@triton.jit
+def transform_exponential_f32_fast(u, inv_lambd, eps_minus):
+    log = tl.where(u >= 1.0 + eps_minus, eps_minus, safe_fast_log_f32(u))
+    return -inv_lambd * log
+
+
+# Iluvatar uses the precise version for numerical stability
+transform_exponential_f32 = transform_exponential_f32_precise
+
+
+@triton.jit
+def transform_exponential_f64(u, inv_lambd, eps_minus):
+    log = tl.where(u >= 1.0 + eps_minus, eps_minus, safe_fast_log_f64(u))
+    return -inv_lambd * log
+
+
+@libentry()
+@libtuner(
+    configs=[
+        triton.Config({"BLOCK": 64}, num_warps=2, num_stages=2),
+        triton.Config({"BLOCK": 128}, num_warps=2, num_stages=2),
+        triton.Config({"BLOCK": 256}, num_warps=4, num_stages=2),
+        triton.Config({"BLOCK": 512}, num_warps=4, num_stages=3),
+        triton.Config({"BLOCK": 1024}, num_warps=4, num_stages=2),
+        triton.Config({"BLOCK": 2048}, num_warps=8, num_stages=2),
+    ],
+    key=["N"],
+)
+@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
+def fused_exponential_kernel_f32(
+    out_ptr, N, inv_lambd, eps_minus, philox_seed, philox_offset, BLOCK: tl.constexpr
+):
+    philox_seed = philox_seed.to(tl.int64)
+    philox_offset = philox_offset.to(tl.int64)
+    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
+    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
+
+    pid = tl.program_id(0)
+    i = pid * BLOCK + tl.arange(0, BLOCK)
+    c0 += i
+    z = c0 * 0
+    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, z, z)
+
+    y0 = transform_exponential_f32(uint_to_uniform_float(r0), inv_lambd, eps_minus)
+    y1 = transform_exponential_f32(uint_to_uniform_float(r1), inv_lambd, eps_minus)
+    y2 = transform_exponential_f32(uint_to_uniform_float(r2), inv_lambd, eps_minus)
+    y3 = transform_exponential_f32(uint_to_uniform_float(r3), inv_lambd, eps_minus)
+
+    start = pid.to(tl.uint64) * BLOCK * 4
+    off0 = start + tl.arange(0, BLOCK)
+    off1 = off0 + BLOCK
+    off2 = off1 + BLOCK
+    off3 = off2 + BLOCK
+
+    tl.store(out_ptr + off0, y0, mask=off0 < N)
+    tl.store(out_ptr + off1, y1, mask=off1 < N)
+    tl.store(out_ptr + off2, y2, mask=off2 < N)
+    tl.store(out_ptr + off3, y3, mask=off3 < N)
+
+
+@libentry()
+@libtuner(
+    configs=[
+        triton.Config({"BLOCK": 64}, num_warps=2, num_stages=2),
+        triton.Config({"BLOCK": 128}, num_warps=2, num_stages=2),
+        triton.Config({"BLOCK": 256}, num_warps=4, num_stages=2),
+        triton.Config({"BLOCK": 512}, num_warps=4, num_stages=3),
+        triton.Config({"BLOCK": 1024}, num_warps=4, num_stages=2),
+        triton.Config({"BLOCK": 2048}, num_warps=8, num_stages=2),
+    ],
+    key=["N"],
+)
+@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
+def fused_exponential_kernel_f64(
+    out_ptr, N, inv_lambd, eps_minus, philox_seed, philox_offset, BLOCK: tl.constexpr
+):
+    philox_seed = philox_seed.to(tl.int64)
+    philox_offset = philox_offset.to(tl.int64)
+    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
+    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
+
+    pid = tl.program_id(0)
+    i = pid * BLOCK + tl.arange(0, BLOCK)
+    c0 += i
+    z = c0 * 0
+    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, z, z)
+
+    u0 = uint_to_uniform_float(paste_u64(r0, r2))
+    u1 = uint_to_uniform_float(paste_u64(r1, r3))
+
+    y0 = transform_exponential_f64(u0, inv_lambd, eps_minus)
+    y1 = transform_exponential_f64(u1, inv_lambd, eps_minus)
+
+    start = pid.to(tl.uint64) * BLOCK * 2
+    off0 = start + tl.arange(0, BLOCK)
+    off1 = off0 + BLOCK
+
+    tl.store(out_ptr + off0, y0, mask=off0 < N)
+    tl.store(out_ptr + off1, y1, mask=off1 < N)
+
+
+def exponential_(x, lambd: float = 1.0, *, generator=None):
+    logger.debug("GEMS_ILUVATAR EXPONENTIAL_")
+
+    dtype = x.dtype
+    device = x.device
+    inplace = x.is_contiguous()
+    assert dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
+
+    N = x.numel()
+
+    # Handle empty tensor
+    if N == 0:
+        return x
+
+    inv_lambd = 1.0 / lambd
+    eps_minus = -0.5 * torch.finfo(dtype).eps
+
+    out = x if inplace else torch.empty_like(x)
+
+    if dtype is torch.float64:
+        UNROLL = 2
+        grid = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),)
+        increment = triton.cdiv(N, UNROLL)
+        philox_seed, philox_offset = philox_backend_seed_offset(
+            increment, generator=generator
+        )
+        with torch_device_fn.device(device):
+            fused_exponential_kernel_f64[grid](
+                out, N, inv_lambd, eps_minus, philox_seed, philox_offset
+            )
+    else:
+        UNROLL = 4
+        grid = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),)
+        increment = triton.cdiv(N, UNROLL)
+        philox_seed, philox_offset = philox_backend_seed_offset(
+            increment, generator=generator
+        )
+        with torch_device_fn.device(device):
+            fused_exponential_kernel_f32[grid](
+                out, N, inv_lambd, eps_minus, philox_seed, philox_offset
+            )
+
+    if not inplace:
+        x.copy_(out)
+    return x
diff --git a/src/flag_gems/runtime/backend/_iluvatar/ops/pow.py b/src/flag_gems/runtime/backend/_iluvatar/ops/pow.py
new file mode 100644
index 0000000000..a8d9d1d199
--- /dev/null
+++ b/src/flag_gems/runtime/backend/_iluvatar/ops/pow.py
@@ -0,0 +1,47 @@
+import logging
+
+import triton
+import triton.language as tl
+
+from flag_gems.utils import pointwise_dynamic, tl_extra_shim
+
+_pow = tl_extra_shim.pow
+logger = logging.getLogger(__name__)
+
+
+@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "BOOL_TO_LONG")])
+@triton.jit
+def pow_func_scalar_tensor(x, exponent):
+    return _pow(x.to(tl.float32), exponent.to(tl.float32))
+
+
+def pow_scalar(A, exponent):
+    """
+    Computes base^exponent where base is a scalar and exponent is a tensor.
+
+    Uses FlagGems standard pointwise_dynamic for hardware compatibility.
+
+    Args:
+        A: Scalar base value
+        exponent: Exponent tensor
+
+    Returns:
+        Output tensor with same shape as exponent
+    """
+    logger.debug("GEMS_ILUVATAR POW_SCALAR")
+    return pow_func_scalar_tensor(A, exponent)
+
+
+def pow_scalar_(A, exponent):
+    """
+    In-place version of pow_scalar.
+
+    Args:
+        A: Scalar base value
+        exponent: Exponent tensor (modified in-place)
+
+    Returns:
+        The modified exponent tensor
+    """
+    logger.debug("GEMS_ILUVATAR POW_SCALAR_")
+    return pow_func_scalar_tensor(A, exponent, out0=exponent)
diff --git a/src/flag_gems/runtime/backend/_iluvatar/ops/repeat.py b/src/flag_gems/runtime/backend/_iluvatar/ops/repeat.py
new file mode 100644
index 0000000000..85edd382cf
--- /dev/null
+++ b/src/flag_gems/runtime/backend/_iluvatar/ops/repeat.py
@@ -0,0 +1,209 @@
+import logging
+
+import torch
+import triton
+import triton.language as tl
+
+from flag_gems.runtime import torch_device_fn
+
+logger = logging.getLogger(__name__)
+
+
+@triton.jit
+def repeat_kernel_3d(
+    inp_ptr,
+    out_ptr,
+    inp_s0,
+    inp_s1,
+    inp_s2,
+    out_s1,
+    out_s2,
+    inp_st0,
+    inp_st1,
+    rep_s2,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """3D repeat: each program handles one output row, writing all dim2 repeats."""
+    row_id = tl.program_id(0)
+
+    o0 = row_id // out_s1
+    o1 = row_id % out_s1
+
+    i0 = o0 % inp_s0
+    i1 = o1 % inp_s1
+
+    inp_row_base = i0 * inp_st0 + i1 * inp_st1
+    out_row_base = row_id * out_s2
+
+    # Load input row once
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < inp_s2
+    vals = tl.load(inp_ptr + inp_row_base + offsets, mask=mask)
+
+    # Write to all dim2 repeats
+    for t in range(rep_s2):
+        tl.store(out_ptr + out_row_base + t * inp_s2 + offsets, vals, mask=mask)
+
+
+@triton.jit
+def repeat_kernel_3d_tiled(
+    inp_ptr,
+    out_ptr,
+    inp_s0,
+    inp_s1,
+    inp_s2,
+    out_s1,
+    out_s2,
+    inp_st0,
+    inp_st1,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """3D repeat tiled: each program handles one (row, tile) pair."""
+    row_id = tl.program_id(0)
+    tile_id = tl.program_id(1)
+
+    o0 = row_id // out_s1
+    o1 = row_id % out_s1
+
+    i0 = o0 % inp_s0
+    i1 = o1 % inp_s1
+
+    inp_row_base = i0 * inp_st0 + i1 * inp_st1
+    out_row_base = row_id * out_s2 + tile_id * inp_s2
+
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < inp_s2
+    vals = tl.load(inp_ptr + inp_row_base + offsets, mask=mask)
+    tl.store(out_ptr + out_row_base + offsets, vals, mask=mask)
+
+
+def _next_power_of_2(n):
+    """Compute next power of 2."""
+    n = int(n)
+    if n <= 0:
+        return 1
+    n -= 1
+    n |= n >> 1
+    n |= n >> 2
+    n |= n >> 4
+    n |= n >> 8
+    n |= n >> 16
+    return n + 1
+
+
+def repeat(inp: torch.Tensor, *sizes) -> torch.Tensor:
+    """Repeat operator for Iluvatar platform.
+
+    Repeat tensor elements along specified dimensions. Uses optimized Triton kernel
+    for 3D tensors with loop or tiled strategy based on repeat size.
+
+    Args:
+        inp: Input tensor to be repeated
+        *sizes: The number of times to repeat this tensor along each dimension
+
+    Returns:
+        Output tensor with repeated elements
+    """
+    logger.debug("GEMS_ILUVATAR REPEAT")
+
+    # Convert sizes to tuple if not already
+    if len(sizes) == 1 and isinstance(sizes[0], (list, tuple)):
+        sizes = tuple(sizes[0])
+    else:
+        sizes = tuple(sizes)
+
+    ndim = inp.ndim
+    sizes_ndim = len(sizes)
+
+    if sizes_ndim > ndim:
+        for _ in range(sizes_ndim - ndim):
+            inp = inp.unsqueeze(0)
+        ndim = inp.ndim
+
+    # 1D: PyTorch is already fast
+    if ndim == 1:
+        return inp.repeat(sizes)
+
+    # Normalize to 3D for unified kernel path
+    orig_ndim = ndim
+    while ndim < 3:
+        inp = inp.unsqueeze(0)
+        sizes = (1,) + sizes
+        ndim += 1
+
+    # For >3D, fall back to PyTorch
+    if ndim > 3:
+        return inp.repeat(sizes)
+
+    inp = inp.contiguous()
+    inp_shape = inp.shape
+    out_shape = (
+        inp_shape[0] * sizes[0],
+        inp_shape[1] * sizes[1],
+        inp_shape[2] * sizes[2],
+    )
+
+    out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)
+    if out.numel() == 0:
+        return out
+
+    num_rows = out_shape[0] * out_shape[1]
+    inp_s2 = inp_shape[2]
+    rep_s2 = sizes[2]
+
+    BLOCK_SIZE = _next_power_of_2(inp_s2)
+    if BLOCK_SIZE < 128:
+        BLOCK_SIZE = 128
+    if BLOCK_SIZE > 4096:
+        BLOCK_SIZE = 4096
+
+    # Fall back to PyTorch if single element is larger than BLOCK_SIZE
+    if inp_s2 > BLOCK_SIZE:
+        result = inp.repeat(sizes)
+        if orig_ndim == 2:
+            return result.reshape(out_shape[1], out_shape[2])
+        return result
+
+    num_warps = 4 if BLOCK_SIZE >= 512 else (2 if BLOCK_SIZE >= 128 else 1)
+
+    with torch_device_fn.device(inp.device.index):
+        # For small rep_s2, use loop kernel (load once, store multiple times)
+        # For large rep_s2, use tiled kernel (better parallelism)
+        if rep_s2 <= 8:
+            grid = (num_rows,)
+            repeat_kernel_3d[grid](
+                inp,
+                out,
+                inp_shape[0],
+                inp_shape[1],
+                inp_s2,
+                out_shape[1],
+                out_shape[2],
+                inp.stride(0),
+                inp.stride(1),
+                rep_s2,
+                BLOCK_SIZE=BLOCK_SIZE,
+                num_warps=num_warps,
+                num_stages=2,
+            )
+        else:
+            grid = (num_rows, rep_s2)
+            repeat_kernel_3d_tiled[grid](
+                inp,
+                out,
+                inp_shape[0],
+                inp_shape[1],
+                inp_s2,
+                out_shape[1],
+                out_shape[2],
+                inp.stride(0),
+                inp.stride(1),
+                BLOCK_SIZE=BLOCK_SIZE,
+                num_warps=num_warps,
+                num_stages=4,
+            )
+
+    # Reshape back to original dimensionality
+    if orig_ndim == 2:
+        return out.reshape(out_shape[1], out_shape[2])
+    return out
diff --git a/src/flag_gems/runtime/backend/_iluvatar/ops/sub.py b/src/flag_gems/runtime/backend/_iluvatar/ops/sub.py
new file mode 100644
index 0000000000..23be7ab217
--- /dev/null
+++ b/src/flag_gems/runtime/backend/_iluvatar/ops/sub.py
@@ -0,0 +1,244 @@
+import logging
+
+import torch
+import triton
+import triton.language as tl
+
+from flag_gems.utils import libentry
+
+logger = logging.getLogger(__name__)
+
+
+@libentry()
+@triton.jit
+def sub_kernel(
+    x_ptr,
+    y_ptr,
+    output_ptr,
+    alpha,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+
+    x = tl.load(x_ptr + offsets, mask=mask)
+    y = tl.load(y_ptr + offsets, mask=mask)
+    result = x - y * alpha
+    tl.store(output_ptr + offsets, result, mask=mask)
+
+
+@libentry()
+@triton.jit
+def sub_scalar_kernel(
+    x_ptr,
+    y_scalar,
+    output_ptr,
+    alpha,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+
+    x = tl.load(x_ptr + offsets, mask=mask)
+    result = x - y_scalar * alpha
+    tl.store(output_ptr + offsets, result, mask=mask)
+
+
+@libentry()
+@triton.jit
+def sub_scalar_tensor_kernel(
+    x_scalar,
+    y_ptr,
+    output_ptr,
+    alpha,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """Compute: x_scalar - y * alpha"""
+    pid = tl.program_id(0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+
+    y = tl.load(y_ptr + offsets, mask=mask)
+    result = x_scalar - y * alpha
+    tl.store(output_ptr + offsets, result, mask=mask)
+
+
+@libentry()
+@triton.jit
+def sub__kernel(
+    x_ptr,
+    y_ptr,
+    alpha,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """In-place subtraction: x = x - y * alpha"""
+    pid = tl.program_id(0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+
+    x = tl.load(x_ptr + offsets, mask=mask)
+    y = tl.load(y_ptr + offsets, mask=mask)
+    result = x - y * alpha
+    tl.store(x_ptr + offsets, result, mask=mask)
+
+
+@libentry()
+@triton.jit
+def sub__scalar_kernel(
+    x_ptr,
+    y_scalar,
+    alpha,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """In-place subtraction with scalar: x = x - y_scalar * alpha"""
+    pid = tl.program_id(0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+
+    x = tl.load(x_ptr + offsets, mask=mask)
+    result = x - y_scalar * alpha
+    tl.store(x_ptr + offsets, result, mask=mask)
+
+
+def sub(A, B, *, alpha=1):
+    """Subtraction operator: output = A - B * alpha
+
+    Supports:
+    - tensor - tensor
+    - tensor - scalar
+    - scalar - tensor
+    - scalar - scalar
+    """
+    logger.debug("GEMS_ILUVATAR SUB")
+
+    # Handle alpha
+    if isinstance(alpha, torch.Tensor):
+        alpha = alpha.item()
+    alpha = float(alpha)
+
+    # Both are tensors
+    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
+        # Handle broadcasting
+        A, B = torch.broadcast_tensors(A, B)
+        A = A.contiguous()
+        B = B.contiguous()
+        n_elements = A.numel()
+
+        # Empty tensor protection
+        if n_elements == 0:
+            return torch.empty_like(A)
+
+        # Handle 0-dimensional tensor (scalar tensor) - use elementwise loop
+        if A.dim() == 0:
+            # For 0-dim tensors, just use Python computation
+            result = A.item() - B.item() * alpha
+            return torch.tensor(result, dtype=A.dtype, device=A.device)
+
+        output = torch.empty_like(A)
+        BLOCK_SIZE = 2048
+        grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
+        sub_kernel[grid](A, B, output, alpha, n_elements, BLOCK_SIZE=BLOCK_SIZE)
+        return output
+
+    # A is tensor, B is scalar
+    elif isinstance(A, torch.Tensor):
+        A = A.contiguous()
+        n_elements = A.numel()
+
+        if n_elements == 0:
+            return torch.empty_like(A)
+
+        # Handle 0-dimensional tensor
+        if A.dim() == 0:
+            result = A.item() - float(B) * alpha
+            return torch.tensor(result, dtype=A.dtype, device=A.device)
+
+        output = torch.empty_like(A)
+        y_scalar = float(B)
+        BLOCK_SIZE = 2048
+        grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
+        sub_scalar_kernel[grid](
+            A, y_scalar, output, alpha, n_elements, BLOCK_SIZE=BLOCK_SIZE
+        )
+        return output
+
+    # B is tensor, A is scalar
+    elif isinstance(B, torch.Tensor):
+        B = B.contiguous()
+        n_elements = B.numel()
+
+        if n_elements == 0:
+            return torch.empty_like(B)
+
+        # Handle 0-dimensional tensor
+        if B.dim() == 0:
+            result = float(A) - B.item() * alpha
+            return torch.tensor(result, dtype=B.dtype, device=B.device)
+
+        output = torch.empty_like(B)
+        x_scalar = float(A)
+        BLOCK_SIZE = 2048
+        grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
+        sub_scalar_tensor_kernel[grid](
+            x_scalar, B, output, alpha, n_elements, BLOCK_SIZE=BLOCK_SIZE
+        )
+        return output
+
+    # Both scalars
+    else:
+        return torch.tensor(A - B * alpha)
+
+
+def sub_(A, B, *, alpha=1):
+    """In-place subtraction operator: A = A - B * alpha"""
+    logger.debug("GEMS_ILUVATAR SUB_")
+
+    # Handle alpha
+    if isinstance(alpha, torch.Tensor):
+        alpha = alpha.item()
+    alpha = float(alpha)
+
+    n_elements = A.numel()
+
+    # Empty tensor protection
+    if n_elements == 0:
+        return A
+
+    # Handle 0-dimensional tensor
+    if A.dim() == 0:
+        if isinstance(B, torch.Tensor):
+            result = A.item() - B.item() * alpha
+        else:
+            result = A.item() - float(B) * alpha
+        A.copy_(torch.tensor(result, dtype=A.dtype, device=A.device))
+        return A
+
+    BLOCK_SIZE = 2048
+
+    if isinstance(B, torch.Tensor):
+        # Handle broadcasting
+        if A.shape != B.shape:
+            B = B.expand_as(A)
+        B = B.contiguous()
+        A = A.contiguous()
+        grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
+        sub__kernel[grid](A, B, alpha, n_elements, BLOCK_SIZE=BLOCK_SIZE)
+    else:
+        y_scalar = float(B)
+        A = A.contiguous()
+        grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
+        sub__scalar_kernel[grid](A, y_scalar, alpha, n_elements, BLOCK_SIZE=BLOCK_SIZE)
+
+    return A