-
Notifications
You must be signed in to change notification settings - Fork 308
[KernelGen] Add optimized add operator with 1.01x speedup #2167
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
de29fec
961b943
f1ef85e
8e1b07b
8fabcc0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,23 @@ | ||
from .add import add, add_
from .div import div_mode, div_mode_
from .exponential_ import exponential_
from .mul import mul
from .pow import pow_scalar, pow_scalar_
from .sub import sub, sub_
from .true_divide import true_divide, true_divide_, true_divide_out

# Public API of this ops package, listed in the same order as the imports
# above (the original placed "true_divide_out" before "true_divide_",
# breaking that convention).
__all__ = [
    "add",
    "add_",
    "div_mode",
    "div_mode_",
    "exponential_",
    "mul",
    "pow_scalar",
    "pow_scalar_",
    "sub",
    "sub_",
    "true_divide",
    "true_divide_",
    "true_divide_out",
]
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
| import logging | ||
|
|
||
| import torch | ||
| import triton | ||
|
|
||
| from flag_gems.utils import pointwise_dynamic | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def add_func(x, y, alpha):
    """Elementwise x + alpha * y for two tensor operands."""
    scaled = y * alpha
    return x + scaled
|
|
||
|
|
||
@pointwise_dynamic(
    is_tensor=[True, False, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def add_func_tensor_scalar(x, y, alpha):
    """Elementwise x + alpha * y where x is a tensor and y is a scalar."""
    scaled = y * alpha
    return x + scaled
|
|
||
|
|
||
@pointwise_dynamic(
    is_tensor=[False, True, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def add_func_scalar_tensor(x, y, alpha):
    """Elementwise x + alpha * y where x is a scalar and y is a tensor."""
    scaled = y * alpha
    return x + scaled
|
|
||
|
|
||
def add(A, B, *, alpha=1):
    """Return ``A + alpha * B``, dispatching on which operands are tensors."""
    logger.debug("GEMS ILUVATAR ADD")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        # Move B onto A's device before launching the two-tensor kernel.
        if B.device != A.device:
            B = B.to(A.device)
        return add_func(A, B, alpha)
    if a_is_tensor:
        return add_func_tensor_scalar(A, B, alpha)
    if b_is_tensor:
        return add_func_scalar_tensor(A, B, alpha)
    # Pure-scalar fallback: compute on the host and wrap in a tensor.
    return torch.tensor(A + B * alpha)
|
|
||
|
|
||
def add_(A, B, *, alpha=1):
    """In-place ``A += alpha * B``; ``A`` must be a ``torch.Tensor``.

    Args:
        A: Destination tensor, updated in place.
        B: Tensor or scalar addend.
        alpha: Scalar multiplier applied to ``B``.

    Returns:
        ``A``, modified in place.

    Raises:
        ValueError: if ``A`` is not a tensor — a scalar has no storage to
            write into, so in-place addition is undefined.
    """
    logger.debug("GEMS ILUVATAR ADD_")
    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
        # Move B onto A's device before launching the two-tensor kernel.
        if B.device != A.device:
            B = B.to(A.device)
        return add_func(A, B, alpha, out0=A)
    elif isinstance(A, torch.Tensor):
        return add_func_tensor_scalar(A, B, alpha, out0=A)
    else:
        # Reached when A is a scalar (B may still be a tensor).  The
        # original message claimed this branch was "Unreachable.", which is
        # wrong; keep ValueError so existing callers' except clauses hold.
        raise ValueError(
            "add_ requires the first argument to be a torch.Tensor for "
            "in-place addition"
        )
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,205 @@ | ||
| import logging | ||
|
|
||
| import torch | ||
| import triton | ||
| import triton.language as tl | ||
|
|
||
| from flag_gems.runtime import torch_device_fn | ||
| from flag_gems.utils import libentry, libtuner | ||
| from flag_gems.utils.random_utils import ( | ||
| philox_backend_seed_offset, | ||
| uint_to_uniform_float, | ||
| ) | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
@triton.jit
def safe_fast_log_f32(x):
    """Fast approximate natural log for float32 values, clamped into (0, 1).

    The input is clamped to [smallest normal float32, 0.99999994] so the
    bit-level decomposition below never sees zero, denormals, or values
    at/above 1.0.
    """
    # Constants are built as `x * 0.0 + c` so they inherit x's block shape.
    min_normal = (x * 0.0 + 1.17549435e-38).to(tl.float32)
    # Largest float32 strictly below 1.0 (1 - 2^-24).
    max_u = x * 0.0 + 0.99999994
    x = tl.minimum(tl.maximum(x, min_normal), max_u)
    # Decompose x = 2^exponent * mantissa by reinterpreting IEEE-754 bits:
    # 23 mantissa bits, exponent bias 127, mantissa scaled back into [1, 2).
    bits = x.to(tl.int32, bitcast=True)
    exponent = (bits >> 23) - 127
    mantissa = (bits & 0x7FFFFF).to(tl.float32) * (1.0 / 8388608.0) + 1.0
    m1 = mantissa - 1.0
    # log(x) = log1p(m1) + exponent * ln(2), with log1p approximated by a
    # 4-term polynomial in m1 (Horner form).
    return (
        m1 * (1.0 + m1 * (-0.5 + m1 * (0.3333333333 - m1 * 0.25)))
        + exponent.to(tl.float32) * 0.6931471805599453
    )
|
|
||
|
|
||
@triton.jit
def safe_fast_log_f64(x):
    """Fast approximate natural log for float64 values, clamped into (0, 1).

    Double-precision twin of ``safe_fast_log_f32``: 52 mantissa bits,
    exponent bias 1023, same 4-term log1p polynomial.
    """
    # Smallest normal float64; shaped like x via `x * 0.0 + c`.
    min_normal = x * 0.0 + 2.2250738585072014e-308
    # Largest float64 strictly below 1.0 (1 - 2^-52).
    max_u = x * 0.0 + (1.0 - 2.220446049250313e-16)
    x = tl.minimum(tl.maximum(x, min_normal), max_u)
    # Reinterpret the IEEE-754 bits to split exponent and mantissa.
    bits = x.to(tl.int64, bitcast=True)
    exponent = (bits >> 52) - 1023
    mantissa = (bits & 0x000FFFFFFFFFFFFF).to(tl.float64) * (
        1.0 / 4503599627370496.0
    ) + 1.0
    m1 = mantissa - 1.0
    # log(x) = log1p(m1) + exponent * ln(2).
    return (
        m1 * (1.0 + m1 * (-0.5 + m1 * (0.3333333333333333 - m1 * 0.25)))
        + exponent.to(tl.float64) * 0.6931471805599453
    )
|
|
||
|
|
||
@triton.jit
def paste_u64(hi: tl.uint32, lo: tl.uint32):
    """Concatenate two 32-bit words into one 64-bit value (hi:lo)."""
    wide_hi = hi.to(tl.uint64) << 32
    return wide_hi | lo.to(tl.uint64)
|
|
||
|
|
||
@triton.jit
def transform_exponential_f32_precise(u, inv_lambd, eps_minus):
    """Map a uniform sample u to an exponential sample as -log(u) / lambd,
    using the precise hardware log.  Samples at or above 1 + eps_minus are
    replaced by eps_minus before negation (eps_minus is negative)."""
    at_upper_edge = u >= 1.0 + eps_minus
    log_u = tl.where(at_upper_edge, eps_minus, tl.math.log(u))
    return -inv_lambd * log_u
|
|
||
|
|
||
@triton.jit
def transform_exponential_f32_fast(u, inv_lambd, eps_minus):
    """Fast-log variant of the uniform -> exponential transform; identical
    to the precise version except it calls safe_fast_log_f32."""
    at_upper_edge = u >= 1.0 + eps_minus
    log_u = tl.where(at_upper_edge, eps_minus, safe_fast_log_f32(u))
    return -inv_lambd * log_u
|
|
||
|
|
||
# Iluvatar uses the precise version for numerical stability.
# NOTE(review): the fast variant is kept (unbound here) for comparison;
# rebinding this alias is the single switch point between the two.
transform_exponential_f32 = transform_exponential_f32_precise
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is a bad programming style... |
||
|
|
||
|
|
||
@triton.jit
def transform_exponential_f64(u, inv_lambd, eps_minus):
    """Double-precision uniform -> exponential transform using the fast
    log approximation; mirrors transform_exponential_f32_fast."""
    at_upper_edge = u >= 1.0 + eps_minus
    log_u = tl.where(at_upper_edge, eps_minus, safe_fast_log_f64(u))
    return -inv_lambd * log_u
|
|
||
|
|
||
@libentry()
@libtuner(
    configs=[
        triton.Config({"BLOCK": 64}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK": 128}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK": 512}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK": 1024}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK": 2048}, num_warps=8, num_stages=2),
    ],
    key=["N"],
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel_f32(
    out_ptr, N, inv_lambd, eps_minus, philox_seed, philox_offset, BLOCK: tl.constexpr
):
    """Fill out_ptr[0:N] with float32 exponential samples.

    Each lane draws four 32-bit philox words and turns each word into one
    sample, so one program instance produces BLOCK * 4 values.
    """
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit philox offset into low/high 32-bit counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    # Advance the low counter per lane so every lane gets a distinct stream.
    pid = tl.program_id(0)
    i = pid * BLOCK + tl.arange(0, BLOCK)
    c0 += i
    # Zero vector with c0's shape/dtype for the two unused counter slots.
    z = c0 * 0
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, z, z)

    # Each random word -> uniform float in (0, 1) -> exponential sample.
    y0 = transform_exponential_f32(uint_to_uniform_float(r0), inv_lambd, eps_minus)
    y1 = transform_exponential_f32(uint_to_uniform_float(r1), inv_lambd, eps_minus)
    y2 = transform_exponential_f32(uint_to_uniform_float(r2), inv_lambd, eps_minus)
    y3 = transform_exponential_f32(uint_to_uniform_float(r3), inv_lambd, eps_minus)

    # This program's output slice starts at pid * BLOCK * 4; the four sample
    # vectors fill four consecutive BLOCK-sized segments of that slice.
    start = pid.to(tl.uint64) * BLOCK * 4
    off0 = start + tl.arange(0, BLOCK)
    off1 = off0 + BLOCK
    off2 = off1 + BLOCK
    off3 = off2 + BLOCK

    # Mask the tail so the last program does not write past N.
    tl.store(out_ptr + off0, y0, mask=off0 < N)
    tl.store(out_ptr + off1, y1, mask=off1 < N)
    tl.store(out_ptr + off2, y2, mask=off2 < N)
    tl.store(out_ptr + off3, y3, mask=off3 < N)
|
|
||
|
|
||
@libentry()
@libtuner(
    configs=[
        triton.Config({"BLOCK": 64}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK": 128}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK": 512}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK": 1024}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK": 2048}, num_warps=8, num_stages=2),
    ],
    key=["N"],
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel_f64(
    out_ptr, N, inv_lambd, eps_minus, philox_seed, philox_offset, BLOCK: tl.constexpr
):
    """Fill out_ptr[0:N] with float64 exponential samples.

    Each lane draws four 32-bit philox words and pastes them pairwise into
    two 64-bit values, so one program instance produces BLOCK * 2 samples.
    """
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit philox offset into low/high 32-bit counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    # Advance the low counter per lane so every lane gets a distinct stream.
    pid = tl.program_id(0)
    i = pid * BLOCK + tl.arange(0, BLOCK)
    c0 += i
    # Zero vector with c0's shape/dtype for the two unused counter slots.
    z = c0 * 0
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, z, z)

    # Combine two 32-bit words into one 64-bit value per double sample.
    u0 = uint_to_uniform_float(paste_u64(r0, r2))
    u1 = uint_to_uniform_float(paste_u64(r1, r3))

    y0 = transform_exponential_f64(u0, inv_lambd, eps_minus)
    y1 = transform_exponential_f64(u1, inv_lambd, eps_minus)

    # This program's output slice starts at pid * BLOCK * 2; the two sample
    # vectors fill two consecutive BLOCK-sized segments of that slice.
    start = pid.to(tl.uint64) * BLOCK * 2
    off0 = start + tl.arange(0, BLOCK)
    off1 = off0 + BLOCK

    # Mask the tail so the last program does not write past N.
    tl.store(out_ptr + off0, y0, mask=off0 < N)
    tl.store(out_ptr + off1, y1, mask=off1 < N)
|
|
||
|
|
||
def exponential_(x, lambd: float = 1.0, *, generator=None):
    """Fill ``x`` in place with samples from Exponential(lambd).

    Args:
        x: Target tensor; must be float16, bfloat16, float32, or float64.
        lambd: Rate parameter; ``1 / lambd`` is passed to the kernel, so a
            zero value would divide by zero here.
        generator: Optional torch RNG generator forwarded to the philox
            seed/offset helper.

    Returns:
        ``x``, modified in place.
    """
    logger.debug("GEMS_ILUVATAR EXPONENTIAL_")

    dtype = x.dtype
    device = x.device
    # The kernels store to linear offsets, so we only write straight into x
    # when it is contiguous; otherwise sample into scratch and copy back.
    write_directly = x.is_contiguous()
    assert dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)

    N = x.numel()

    # Nothing to sample for an empty tensor.
    if N == 0:
        return x

    inv_lambd = 1.0 / lambd
    # Upper-edge guard used by the transform kernels: -eps/2 for this dtype.
    eps_minus = -0.5 * torch.finfo(dtype).eps

    out = x if write_directly else torch.empty_like(x)

    # float64 consumes two 32-bit philox words per sample (2 samples per
    # lane); the narrower dtypes consume one word each (4 samples per lane).
    # The two branches previously duplicated the grid/seed/launch logic.
    if dtype is torch.float64:
        kernel, unroll = fused_exponential_kernel_f64, 2
    else:
        kernel, unroll = fused_exponential_kernel_f32, 4

    grid = lambda meta: (triton.cdiv(N, meta["BLOCK"] * unroll),)
    increment = triton.cdiv(N, unroll)
    philox_seed, philox_offset = philox_backend_seed_offset(
        increment, generator=generator
    )
    with torch_device_fn.device(device):
        kernel[grid](out, N, inv_lambd, eps_minus, philox_seed, philox_offset)

    if not write_directly:
        # NOTE(review): `out` comes from empty_like on a non-contiguous x and
        # is filled via linear offsets — confirm its layout is contiguous for
        # exotic strides before relying on this path.
        x.copy_(out)
    return x
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,47 @@ | ||
| import logging | ||
|
|
||
| import triton | ||
| import triton.language as tl | ||
|
|
||
| from flag_gems.utils import pointwise_dynamic, tl_extra_shim | ||
|
|
||
| _pow = tl_extra_shim.pow | ||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "BOOL_TO_LONG")])
@triton.jit
def pow_func_scalar_tensor(x, exponent):
    # Scalar base `x` raised elementwise to tensor `exponent`.
    # NOTE(review): both operands are unconditionally cast to float32 before
    # _pow, so float64 exponents presumably lose precision here — confirm
    # this is intended.
    return _pow(x.to(tl.float32), exponent.to(tl.float32))
|
|
||
|
|
||
def pow_scalar(A, exponent):
    """Raise a scalar base to a tensor of exponents, elementwise.

    Delegates to the ``pointwise_dynamic`` kernel so the operation stays
    portable across FlagGems backends.

    Args:
        A: Scalar base value.
        exponent: Exponent tensor.

    Returns:
        A tensor shaped like ``exponent`` containing ``A ** exponent``.
    """
    logger.debug("GEMS_ILUVATAR POW_SCALAR")
    result = pow_func_scalar_tensor(A, exponent)
    return result
|
|
||
|
|
||
def pow_scalar_(A, exponent):
    """In-place variant of :func:`pow_scalar`.

    Args:
        A: Scalar base value.
        exponent: Exponent tensor; overwritten with ``A ** exponent``.

    Returns:
        The ``exponent`` tensor, modified in place.
    """
    logger.debug("GEMS_ILUVATAR POW_SCALAR_")
    result = pow_func_scalar_tensor(A, exponent, out0=exponent)
    return result
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Useless function?