-
Notifications
You must be signed in to change notification settings - Fork 308
[KernelGen] Add optimized ones operator for Iluvatar platform #2186
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
zacliu2023
wants to merge
10
commits into
flagos-ai:master
Choose a base branch
from
zacliu2023:kernelgen2.0-tianshu-ones
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
de29fec
[kernelgen2.0] Add exponential_ operator for Iluvatar platform
ftgreat 961b943
[kernelgen2.0] Add pow_scalar operator for Iluvatar platform
ftgreat f1ef85e
[kernelgen2.0] Add sub operator for Iluvatar platform
ftgreat 8e1b07b
[kernelgen2.0] Add optimized add operator for Iluvatar platform
ftgreat dd83308
[kernelgen2.0] Add optimized repeat operator for Iluvatar platform
ftgreat c5d4b1a
[kernelgen2.0] Register repeat and ones operators in _iluvatar backend
ftgreat 34c1dff
[kernelgen2.0] Add optimized ones operator for Iluvatar platform
ftgreat 00a688e
Fix flake8, isort, and black lint errors
zacliu2023 0be8855
Fix black formatting in repeat.py - one param per line
zacliu2023 b1ae94e
Fix F841: remove unused variable orig_sizes in repeat.py
zacliu2023 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,21 @@ | ||
| from .add import add, add_ | ||
| from .div import div_mode, div_mode_ | ||
| from .exponential_ import exponential_ | ||
| from .ones import ones | ||
| from .pow import pow_scalar, pow_scalar_ | ||
| from .repeat import repeat | ||
| from .sub import sub, sub_ | ||
|
|
||
| __all__ = [ | ||
| "add", | ||
| "add_", | ||
| "div_mode", | ||
| "div_mode_", | ||
| "exponential_", | ||
| "ones", | ||
| "pow_scalar", | ||
| "pow_scalar_", | ||
| "repeat", | ||
| "sub", | ||
| "sub_", | ||
| ] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,57 @@ | ||||||
| import logging | ||||||
|
|
||||||
| import torch | ||||||
| import triton | ||||||
|
|
||||||
| from flag_gems.utils import pointwise_dynamic | ||||||
|
|
||||||
| logger = logging.getLogger(__name__) | ||||||
|
|
||||||
|
|
||||||
| @pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[(0, 1, "DEFAULT")]) | ||||||
| @triton.jit | ||||||
| def add_func(x, y, alpha): | ||||||
| return x + y * alpha | ||||||
|
|
||||||
|
|
||||||
| @pointwise_dynamic( | ||||||
| is_tensor=[True, False, False], promotion_methods=[(0, 1, "DEFAULT")] | ||||||
| ) | ||||||
| @triton.jit | ||||||
| def add_func_tensor_scalar(x, y, alpha): | ||||||
| return x + y * alpha | ||||||
|
|
||||||
|
|
||||||
| @pointwise_dynamic( | ||||||
| is_tensor=[False, True, False], promotion_methods=[(0, 1, "DEFAULT")] | ||||||
| ) | ||||||
| @triton.jit | ||||||
| def add_func_scalar_tensor(x, y, alpha): | ||||||
| return x + y * alpha | ||||||
|
|
||||||
|
|
||||||
| def add(A, B, *, alpha=1): | ||||||
| logger.debug("GEMS ILUVATAR ADD") | ||||||
| if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor): | ||||||
| if B.device != A.device: | ||||||
| B = B.to(A.device) | ||||||
| return add_func(A, B, alpha) | ||||||
| elif isinstance(A, torch.Tensor): | ||||||
| return add_func_tensor_scalar(A, B, alpha) | ||||||
| elif isinstance(B, torch.Tensor): | ||||||
| return add_func_scalar_tensor(A, B, alpha) | ||||||
| else: | ||||||
| return torch.tensor(A + B * alpha) | ||||||
|
|
||||||
|
|
||||||
| def add_(A, B, *, alpha=1): | ||||||
| logger.debug("GEMS ILUVATAR ADD_") | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor): | ||||||
| if B.device != A.device: | ||||||
| B = B.to(A.device) | ||||||
| return add_func(A, B, alpha, out0=A) | ||||||
| elif isinstance(A, torch.Tensor): | ||||||
| return add_func_tensor_scalar(A, B, alpha, out0=A) | ||||||
| # Note: scalar_tensor case not supported for in-place | ||||||
| else: | ||||||
| raise ValueError("Unreachable.") | ||||||
205 changes: 205 additions & 0 deletions
205
src/flag_gems/runtime/backend/_iluvatar/ops/exponential_.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,205 @@ | ||
| import logging | ||
|
|
||
| import torch | ||
| import triton | ||
| import triton.language as tl | ||
|
|
||
| from flag_gems.runtime import torch_device_fn | ||
| from flag_gems.utils import libentry, libtuner | ||
| from flag_gems.utils.random_utils import ( | ||
| philox_backend_seed_offset, | ||
| uint_to_uniform_float, | ||
| ) | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| @triton.jit | ||
| def safe_fast_log_f32(x): | ||
| min_normal = (x * 0.0 + 1.17549435e-38).to(tl.float32) | ||
| max_u = x * 0.0 + 0.99999994 | ||
| x = tl.minimum(tl.maximum(x, min_normal), max_u) | ||
| bits = x.to(tl.int32, bitcast=True) | ||
| exponent = (bits >> 23) - 127 | ||
| mantissa = (bits & 0x7FFFFF).to(tl.float32) * (1.0 / 8388608.0) + 1.0 | ||
| m1 = mantissa - 1.0 | ||
| return ( | ||
| m1 * (1.0 + m1 * (-0.5 + m1 * (0.3333333333 - m1 * 0.25))) | ||
| + exponent.to(tl.float32) * 0.6931471805599453 | ||
| ) | ||
|
|
||
|
|
||
| @triton.jit | ||
| def safe_fast_log_f64(x): | ||
| min_normal = x * 0.0 + 2.2250738585072014e-308 | ||
| max_u = x * 0.0 + (1.0 - 2.220446049250313e-16) | ||
| x = tl.minimum(tl.maximum(x, min_normal), max_u) | ||
| bits = x.to(tl.int64, bitcast=True) | ||
| exponent = (bits >> 52) - 1023 | ||
| mantissa = (bits & 0x000FFFFFFFFFFFFF).to(tl.float64) * ( | ||
| 1.0 / 4503599627370496.0 | ||
| ) + 1.0 | ||
| m1 = mantissa - 1.0 | ||
| return ( | ||
| m1 * (1.0 + m1 * (-0.5 + m1 * (0.3333333333333333 - m1 * 0.25))) | ||
| + exponent.to(tl.float64) * 0.6931471805599453 | ||
| ) | ||
|
|
||
|
|
||
| @triton.jit | ||
| def paste_u64(hi: tl.uint32, lo: tl.uint32): | ||
| return (hi.to(tl.uint64) << 32) | lo.to(tl.uint64) | ||
|
|
||
|
|
||
| @triton.jit | ||
| def transform_exponential_f32_precise(u, inv_lambd, eps_minus): | ||
| log = tl.where(u >= 1.0 + eps_minus, eps_minus, tl.math.log(u)) | ||
| return -inv_lambd * log | ||
|
|
||
|
|
||
| @triton.jit | ||
| def transform_exponential_f32_fast(u, inv_lambd, eps_minus): | ||
| log = tl.where(u >= 1.0 + eps_minus, eps_minus, safe_fast_log_f32(u)) | ||
| return -inv_lambd * log | ||
|
|
||
|
|
||
| # Iluvatar uses the precise version for numerical stability | ||
| transform_exponential_f32 = transform_exponential_f32_precise | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't rename functions in this style, it is easy to get things messy. |
||
|
|
||
| @triton.jit | ||
| def transform_exponential_f64(u, inv_lambd, eps_minus): | ||
| log = tl.where(u >= 1.0 + eps_minus, eps_minus, safe_fast_log_f64(u)) | ||
| return -inv_lambd * log | ||
|
|
||
|
|
||
| @libentry() | ||
| @libtuner( | ||
| configs=[ | ||
| triton.Config({"BLOCK": 64}, num_warps=2, num_stages=2), | ||
| triton.Config({"BLOCK": 128}, num_warps=2, num_stages=2), | ||
| triton.Config({"BLOCK": 256}, num_warps=4, num_stages=2), | ||
| triton.Config({"BLOCK": 512}, num_warps=4, num_stages=3), | ||
| triton.Config({"BLOCK": 1024}, num_warps=4, num_stages=2), | ||
| triton.Config({"BLOCK": 2048}, num_warps=8, num_stages=2), | ||
| ], | ||
| key=["N"], | ||
| ) | ||
| @triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"]) | ||
| def fused_exponential_kernel_f32( | ||
| out_ptr, N, inv_lambd, eps_minus, philox_seed, philox_offset, BLOCK: tl.constexpr | ||
| ): | ||
| philox_seed = philox_seed.to(tl.int64) | ||
| philox_offset = philox_offset.to(tl.int64) | ||
| c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32) | ||
| c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32) | ||
|
|
||
| pid = tl.program_id(0) | ||
| i = pid * BLOCK + tl.arange(0, BLOCK) | ||
| c0 += i | ||
| z = c0 * 0 | ||
| r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, z, z) | ||
|
|
||
| y0 = transform_exponential_f32(uint_to_uniform_float(r0), inv_lambd, eps_minus) | ||
| y1 = transform_exponential_f32(uint_to_uniform_float(r1), inv_lambd, eps_minus) | ||
| y2 = transform_exponential_f32(uint_to_uniform_float(r2), inv_lambd, eps_minus) | ||
| y3 = transform_exponential_f32(uint_to_uniform_float(r3), inv_lambd, eps_minus) | ||
|
|
||
| start = pid.to(tl.uint64) * BLOCK * 4 | ||
| off0 = start + tl.arange(0, BLOCK) | ||
| off1 = off0 + BLOCK | ||
| off2 = off1 + BLOCK | ||
| off3 = off2 + BLOCK | ||
|
|
||
| tl.store(out_ptr + off0, y0, mask=off0 < N) | ||
| tl.store(out_ptr + off1, y1, mask=off1 < N) | ||
| tl.store(out_ptr + off2, y2, mask=off2 < N) | ||
| tl.store(out_ptr + off3, y3, mask=off3 < N) | ||
|
|
||
|
|
||
| @libentry() | ||
| @libtuner( | ||
| configs=[ | ||
| triton.Config({"BLOCK": 64}, num_warps=2, num_stages=2), | ||
| triton.Config({"BLOCK": 128}, num_warps=2, num_stages=2), | ||
| triton.Config({"BLOCK": 256}, num_warps=4, num_stages=2), | ||
| triton.Config({"BLOCK": 512}, num_warps=4, num_stages=3), | ||
| triton.Config({"BLOCK": 1024}, num_warps=4, num_stages=2), | ||
| triton.Config({"BLOCK": 2048}, num_warps=8, num_stages=2), | ||
| ], | ||
| key=["N"], | ||
| ) | ||
| @triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"]) | ||
| def fused_exponential_kernel_f64( | ||
| out_ptr, N, inv_lambd, eps_minus, philox_seed, philox_offset, BLOCK: tl.constexpr | ||
| ): | ||
| philox_seed = philox_seed.to(tl.int64) | ||
| philox_offset = philox_offset.to(tl.int64) | ||
| c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32) | ||
| c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32) | ||
|
|
||
| pid = tl.program_id(0) | ||
| i = pid * BLOCK + tl.arange(0, BLOCK) | ||
| c0 += i | ||
| z = c0 * 0 | ||
| r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, z, z) | ||
|
|
||
| u0 = uint_to_uniform_float(paste_u64(r0, r2)) | ||
| u1 = uint_to_uniform_float(paste_u64(r1, r3)) | ||
|
|
||
| y0 = transform_exponential_f64(u0, inv_lambd, eps_minus) | ||
| y1 = transform_exponential_f64(u1, inv_lambd, eps_minus) | ||
|
|
||
| start = pid.to(tl.uint64) * BLOCK * 2 | ||
| off0 = start + tl.arange(0, BLOCK) | ||
| off1 = off0 + BLOCK | ||
|
|
||
| tl.store(out_ptr + off0, y0, mask=off0 < N) | ||
| tl.store(out_ptr + off1, y1, mask=off1 < N) | ||
|
|
||
|
|
||
| def exponential_(x, lambd: float = 1.0, *, generator=None): | ||
| logger.debug("GEMS_ILUVATAR EXPONENTIAL_") | ||
|
|
||
| dtype = x.dtype | ||
| device = x.device | ||
| inplace = x.is_contiguous() | ||
| assert dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64) | ||
|
|
||
| N = x.numel() | ||
|
|
||
| # Handle empty tensor | ||
| if N == 0: | ||
| return x | ||
|
|
||
| inv_lambd = 1.0 / lambd | ||
| eps_minus = -0.5 * torch.finfo(dtype).eps | ||
|
|
||
| out = x if inplace else torch.empty_like(x) | ||
|
|
||
| if dtype is torch.float64: | ||
| UNROLL = 2 | ||
| grid = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),) | ||
| increment = triton.cdiv(N, UNROLL) | ||
| philox_seed, philox_offset = philox_backend_seed_offset( | ||
| increment, generator=generator | ||
| ) | ||
| with torch_device_fn.device(device): | ||
| fused_exponential_kernel_f64[grid]( | ||
| out, N, inv_lambd, eps_minus, philox_seed, philox_offset | ||
| ) | ||
| else: | ||
| UNROLL = 4 | ||
| grid = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),) | ||
| increment = triton.cdiv(N, UNROLL) | ||
| philox_seed, philox_offset = philox_backend_seed_offset( | ||
| increment, generator=generator | ||
| ) | ||
| with torch_device_fn.device(device): | ||
| fused_exponential_kernel_f32[grid]( | ||
| out, N, inv_lambd, eps_minus, philox_seed, philox_offset | ||
| ) | ||
|
|
||
| if not inplace: | ||
| x.copy_(out) | ||
| return x | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,45 @@ | ||
| import logging | ||
|
|
||
| import torch | ||
| import triton | ||
| import triton.language as tl | ||
|
|
||
| from flag_gems.runtime import device, torch_device_fn | ||
| from flag_gems.utils import libentry | ||
| from flag_gems.utils.shape_utils import volume | ||
|
|
||
| device_ = device | ||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| @libentry() | ||
| @triton.jit | ||
| def ones_kernel( | ||
| output_ptr, | ||
| n_elements, | ||
| BLOCK_SIZE: tl.constexpr, | ||
| ): | ||
| pid = tl.program_id(0) | ||
| block_start = pid * BLOCK_SIZE | ||
| offsets = block_start + tl.arange(0, BLOCK_SIZE) | ||
| mask = offsets < n_elements | ||
| tl.store(output_ptr + offsets, 1.0, mask=mask) | ||
|
|
||
|
|
||
| def ones(size, *, dtype=None, layout=None, device=None, pin_memory=None): | ||
| logger.debug("GEMS_ILUVATAR ONES") | ||
| if dtype is None: | ||
| dtype = torch.get_default_dtype() | ||
| if device is None: | ||
| device = torch.device(device_.name) | ||
|
|
||
| out = torch.empty(size, device=device, dtype=dtype) | ||
| N = volume(size) | ||
| if N == 0: | ||
| return out | ||
|
|
||
| BLOCK_SIZE = 2048 | ||
| grid = (triton.cdiv(N, BLOCK_SIZE),) | ||
| with torch_device_fn.device(device): | ||
| ones_kernel[grid](out, N, BLOCK_SIZE) | ||
| return out |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,47 @@ | ||
| import logging | ||
|
|
||
| import triton | ||
| import triton.language as tl | ||
|
|
||
| from flag_gems.utils import pointwise_dynamic, tl_extra_shim | ||
|
|
||
| _pow = tl_extra_shim.pow | ||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| @pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "BOOL_TO_LONG")]) | ||
| @triton.jit | ||
| def pow_func_scalar_tensor(x, exponent): | ||
| return _pow(x.to(tl.float32), exponent.to(tl.float32)) | ||
|
|
||
|
|
||
| def pow_scalar(A, exponent): | ||
| """ | ||
| Computes base^exponent where base is a scalar and exponent is a tensor. | ||
|
|
||
| Uses FlagGems standard pointwise_dynamic for hardware compatibility. | ||
|
|
||
| Args: | ||
| A: Scalar base value | ||
| exponent: Exponent tensor | ||
|
|
||
| Returns: | ||
| Output tensor with same shape as exponent | ||
| """ | ||
| logger.debug("GEMS_ILUVATAR POW_SCALAR") | ||
| return pow_func_scalar_tensor(A, exponent) | ||
|
|
||
|
|
||
| def pow_scalar_(A, exponent): | ||
| """ | ||
| In-place version of pow_scalar. | ||
|
|
||
| Args: | ||
| A: Scalar base value | ||
| exponent: Exponent tensor (modified in-place) | ||
|
|
||
| Returns: | ||
| The modified exponent tensor | ||
| """ | ||
| logger.debug("GEMS_ILUVATAR POW_SCALAR_") | ||
| return pow_func_scalar_tensor(A, exponent, out0=exponent) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.