-
Notifications
You must be signed in to change notification settings - Fork 310
[PerfXLab] optimize fill performance #2216
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,14 @@ | ||
import re

import triton

# Gate on the numeric (major, minor) version. Comparing the raw version
# strings is lexicographic, so "3.10" would sort before "3.4" and these
# kernels would be wrongly disabled on newer Triton releases.
_version_match = re.match(r"(\d+)\.(\d+)", triton.__version__)
_triton_version = (
    tuple(int(part) for part in _version_match.groups())
    if _version_match
    else (0, 0)  # unparsable version string: keep the kernels disabled
)

if _triton_version >= (3, 4):
    from .fill import (  # noqa: F401
        fill_scalar,
        fill_scalar_,
        fill_scalar_out,
        fill_tensor,
        fill_tensor_,
        fill_tensor_out,
    )
    from .mm import mm, mm_out  # noqa: F401
|
|
||
| __all__ = ["*"] |
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,121 @@ | ||||||
| import logging | ||||||
|
|
||||||
| import torch | ||||||
| import triton | ||||||
| import triton.language as tl | ||||||
|
|
||||||
| from flag_gems.runtime import torch_device_fn | ||||||
|
|
||||||
| logger = logging.getLogger(__name__) | ||||||
|
|
||||||
|
|
||||||
@triton.jit
def fill_scalar_kernel(
    out_ptr,
    value_scalar,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Store ``value_scalar`` into the first ``n_elements`` slots at ``out_ptr``.

    Each program handles one block of ``BLOCK_SIZE`` consecutive offsets;
    offsets at or beyond ``n_elements`` are masked out.
    """
    # Widen the program id so that pid * BLOCK_SIZE cannot wrap around in
    # 32-bit arithmetic when the tensor is very large.
    pid = tl.program_id(axis=0).to(tl.int64)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Take the element type from the pointer itself. Loading from the output
    # just to learn its dtype would add a full read of the destination.
    fill_val = tl.full([BLOCK_SIZE], value_scalar, dtype=out_ptr.dtype.element_ty)
    tl.store(out_ptr + offsets, fill_val, mask=mask)
|
|
||||||
|
|
||||||
@triton.jit
def fill_tensor_kernel(
    out_ptr,
    value_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Store the single value at ``value_ptr`` into the first ``n_elements`` slots at ``out_ptr``.

    Each program handles one block of ``BLOCK_SIZE`` consecutive offsets;
    offsets at or beyond ``n_elements`` are masked out.
    """
    # Widen the program id so that pid * BLOCK_SIZE cannot wrap around in
    # 32-bit arithmetic when the tensor is very large.
    pid = tl.program_id(axis=0).to(tl.int64)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # One scalar load per program; the store applies it to the whole block.
    val = tl.load(value_ptr)
    tl.store(out_ptr + offsets, val, mask=mask)
|
|
||||||
|
|
||||||
def fill_scalar(input, value):
    """Return a new tensor allocated like ``input`` with every element set to ``value``.

    Args:
        input: Tensor whose ``torch.empty_like`` allocation receives the fill.
        value: Scalar passed to the kernel as the fill value.

    Returns:
        The newly allocated, filled tensor.
    """
    # Log tag uses the "GEMS_HOPPER <OP>" form requested in review.
    logger.debug("GEMS_HOPPER FILL_SCALAR")
    block_size = 1024
    out = torch.empty_like(input)
    n_elements = out.numel()
    grid = (triton.cdiv(n_elements, block_size),)
    with torch_device_fn.device(input.device):
        fill_scalar_kernel[grid](out, value, n_elements, BLOCK_SIZE=block_size)
    return out
|
|
||||||
|
|
||||||
def fill_scalar_out(input, value, *, out=None):
    """Fill ``out`` with the scalar ``value`` and return it.

    Args:
        input: Reference tensor; only used when ``out`` is ``None``, in which
            case the result is a new tensor from ``fill_scalar(input, value)``.
        value: Scalar fill value.
        out: Destination tensor, or ``None`` to allocate a new one.

    Returns:
        ``out`` after filling, or the newly allocated tensor.
    """
    # Log tag uses the "GEMS_HOPPER <OP>" form requested in review.
    logger.debug("GEMS_HOPPER FILL_SCALAR_OUT")
    if out is None:
        return fill_scalar(input, value)
    if not out.is_contiguous():
        # The kernel writes numel() consecutive elements starting at the data
        # pointer, which is only correct for a contiguous layout. For a
        # strided destination, fill a fresh tensor and copy it in.
        return out.copy_(fill_scalar(out, value))
    block_size = 1024
    n_elements = out.numel()
    grid = (triton.cdiv(n_elements, block_size),)
    # Only ``out`` is touched by the kernel, so launch on its device.
    with torch_device_fn.device(out.device):
        fill_scalar_kernel[grid](out, value, n_elements, BLOCK_SIZE=block_size)
    return out
|
|
||||||
|
|
||||||
def fill_tensor(input, value):
    """Return a new tensor allocated like ``input`` filled with the 0-dim tensor ``value``.

    Args:
        input: Tensor whose ``torch.empty_like`` allocation receives the fill.
        value: 0-dimensional tensor holding the fill value. A non-CUDA value
            is converted with ``.item()`` and routed through ``fill_scalar``.

    Returns:
        The newly allocated, filled tensor.

    Raises:
        RuntimeError: If ``value`` is not 0-dimensional.
    """
    # Validate before the host-value shortcut so that a one-element tensor
    # with ndim > 0 is rejected regardless of which device it lives on.
    if value.ndim != 0:
        raise RuntimeError(
            f"fill only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    if not value.is_cuda:
        return fill_scalar(input, value.item())
    logger.debug("GEMS_HOPPER FILL_TENSOR")
    block_size = 1024
    out = torch.empty_like(input)
    n_elements = out.numel()
    grid = (triton.cdiv(n_elements, block_size),)
    with torch_device_fn.device(input.device):
        fill_tensor_kernel[grid](out, value, n_elements, BLOCK_SIZE=block_size)
    return out
|
|
||||||
|
|
||||||
def fill_tensor_out(input, value, *, out=None):
    """Fill ``out`` with the 0-dim tensor ``value`` and return it.

    Args:
        input: Reference tensor; only used when ``out`` is ``None``, in which
            case the result is a new tensor from ``fill_tensor(input, value)``.
        value: 0-dimensional tensor holding the fill value. A non-CUDA value
            is converted with ``.item()`` and routed through ``fill_scalar_out``.
        out: Destination tensor, or ``None`` to allocate a new one.

    Returns:
        ``out`` after filling, or the newly allocated tensor.

    Raises:
        RuntimeError: If ``value`` is not 0-dimensional.
    """
    logger.debug("GEMS_HOPPER FILL_TENSOR_OUT")
    # Validate before the host-value shortcut so that a one-element tensor
    # with ndim > 0 is rejected regardless of which device it lives on.
    if value.ndim != 0:
        raise RuntimeError(
            f"fill only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    if out is None:
        return fill_tensor(input, value)
    if not value.is_cuda:
        return fill_scalar_out(input, value.item(), out=out)
    if not out.is_contiguous():
        # The kernel writes numel() consecutive elements starting at the data
        # pointer, which is only correct for a contiguous layout. For a
        # strided destination, fill a fresh tensor and copy it in.
        return out.copy_(fill_tensor(out, value))
    block_size = 1024
    n_elements = out.numel()
    grid = (triton.cdiv(n_elements, block_size),)
    # Only ``out`` is written by the kernel, so launch on its device.
    with torch_device_fn.device(out.device):
        fill_tensor_kernel[grid](out, value, n_elements, BLOCK_SIZE=block_size)
    return out
|
|
||||||
|
|
||||||
def fill_tensor_(self, value):
    """Fill ``self`` in place with the 0-dim tensor ``value`` and return ``self``.

    Args:
        self: Tensor to overwrite.
        value: 0-dimensional tensor holding the fill value. A non-CUDA value
            is converted with ``.item()`` and routed through ``fill_scalar_``.

    Returns:
        ``self`` after filling.

    Raises:
        RuntimeError: If ``value`` is not 0-dimensional.
    """
    # Validate before the host-value shortcut so that a one-element tensor
    # with ndim > 0 is rejected regardless of which device it lives on.
    if value.ndim != 0:
        raise RuntimeError(
            f"fill only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    if not value.is_cuda:
        return fill_scalar_(self, value.item())
    logger.debug("GEMS_HOPPER FILL_TENSOR_")
    if not self.is_contiguous():
        # The kernel writes numel() consecutive elements starting at the data
        # pointer, which is only correct for a contiguous layout. For a
        # strided view, fill a fresh tensor and copy it in.
        return self.copy_(fill_tensor(self, value))
    block_size = 1024
    n_elements = self.numel()
    grid = (triton.cdiv(n_elements, block_size),)
    with torch_device_fn.device(self.device):
        fill_tensor_kernel[grid](self, value, n_elements, BLOCK_SIZE=block_size)
    return self
|
|
||||||
|
|
||||||
def fill_scalar_(self, value):
    """Fill ``self`` in place with the scalar ``value`` and return ``self``.

    Args:
        self: Tensor to overwrite.
        value: Scalar fill value.

    Returns:
        ``self`` after filling.
    """
    logger.debug("GEMS_HOPPER FILL_SCALAR_")
    if not self.is_contiguous():
        # The kernel writes numel() consecutive elements starting at the data
        # pointer, which is only correct for a contiguous layout. For a
        # strided view (e.g. a column slice), fill a fresh tensor and copy it in.
        return self.copy_(fill_scalar(self, value))
    block_size = 1024
    n_elements = self.numel()
    grid = (triton.cdiv(n_elements, block_size),)
    with torch_device_fn.device(self.device):
        fill_scalar_kernel[grid](self, value, n_elements, BLOCK_SIZE=block_size)
    return self
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is the FlagGems benchmark for `fill_scalar_` commented out?