refactor: backend_requirement + supported_compute_capability decorator for gemm #2000
```diff
@@ -45,7 +45,12 @@
 from .cuda_utils import checkCudaErrors
 from .jit.cubin_loader import get_cubin
 from .jit.env import FLASHINFER_CUBIN_DIR
-from .utils import ceil_div, round_up
+from .utils import (
+    ceil_div,
+    round_up,
+    supported_compute_capability,
+    backend_requirement,
+)


 class GemmType(enum.Enum):
```
```diff
@@ -1358,24 +1363,27 @@ def m_grouped_fp8_gemm_nt_masked_sm10x(
     runtime(**all_kwargs)


-def m_grouped_fp8_gemm_nt_contiguous(
+@supported_compute_capability([100, 103])
+def _check_group_deepgemm_fp8_nt_contiguous_problem_size(
     a_fp8: Tuple[torch.Tensor, torch.Tensor],
     b_fp8: Tuple[torch.Tensor, torch.Tensor],
     d: torch.Tensor,
     m_indices: torch.Tensor,
     recipe: Optional[Tuple[int, int, int]] = None,
     compiled_dims: str = "nk",
-) -> None:
-    # Compiled dims can be upper cases
-    compiled_dims = compiled_dims.lower()
-
+) -> bool:
     # NOTES: shape must be `[M, K] @ [G, N, K].mT`
     major_a = get_major_type_ab(a_fp8[0])
     major_b = get_major_type_ab(b_fp8[0])
-    assert major_a == MajorTypeAB.KMajor
-    if must_be_k_major():
-        assert major_b == MajorTypeAB.KMajor
-    assert m_indices.is_contiguous()
+    if major_a != MajorTypeAB.KMajor:
+        raise ValueError(f"major_a must be KMajor, but got {major_a}")
+    if must_be_k_major() and (major_b != MajorTypeAB.KMajor):
+        raise ValueError(f"major_b must be KMajor, but got {major_b}")
+
+    if not m_indices.is_contiguous():
+        raise ValueError(
+            f"m_indices must be contiguous, but got {m_indices.is_contiguous()}"
+        )

     a, sfa = a_fp8
     b, sfb = b_fp8
```
```diff
@@ -1385,15 +1393,48 @@ def m_grouped_fp8_gemm_nt_contiguous(
     m__ = m_indices.numel()

```
**Contributor:** The check for positive dimensions (the former `assert n > 0 and k > 0 and num_groups > 0`) does not appear to be carried over into the new checks below.
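One way to restore that check in the new `ValueError` style (an illustrative sketch, not necessarily the reviewer's exact suggestion) would be:

```python
    # Hypothetical re-addition of the dropped positive-dimension check,
    # phrased like the surrounding ValueError-based checks.
    if n <= 0 or k <= 0 or num_groups <= 0:
        raise ValueError(
            f"n, k, num_groups must be greater than 0, but got n = {n}, k = {k}, num_groups = {num_groups}"
        )
```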
```diff
     # Type and shape checks
-    assert m == m_ == m__ and n == n_ and k == k_
-    assert n > 0 and k > 0 and num_groups > 0
-    assert a.dtype == torch.float8_e4m3fn
-    assert b.dtype == torch.float8_e4m3fn
-    assert d.dtype == torch.bfloat16
-    assert m_indices.dtype == torch.int32
+    if m != m_ or k != k_ or n != n_ or m__ != m_ or num_groups != m__:
+        raise ValueError(
+            f"Shape mismatch. m = {m}, m_ = {m_}, k = {k}, k_ = {k_}, n = {n}, n_ = {n_}, m__ = {m__}"
+        )
+    if a.dtype != torch.float8_e4m3fn:
+        raise ValueError(f"a must be float8_e4m3fn, but got {a.dtype}")
+    if b.dtype != torch.float8_e4m3fn:
+        raise ValueError(f"b must be float8_e4m3fn, but got {b.dtype}")
+    if d.dtype != torch.bfloat16:
+        raise ValueError(f"d must be bfloat16, but got {d.dtype}")
+    if m_indices.dtype != torch.int32:
+        raise ValueError(f"m_indices must be int32, but got {m_indices.dtype}")

     # D must be N-major
-    assert get_major_type_cd(d) == MajorTypeCD.NMajor
+    if get_major_type_cd(d) != MajorTypeCD.NMajor:
+        raise ValueError(f"d must be N-major, but got {get_major_type_cd(d)}")
+
+    return True
+
+
+@backend_requirement(
+    {},
+    common_check=_check_group_deepgemm_fp8_nt_contiguous_problem_size,
+)
+def m_grouped_fp8_gemm_nt_contiguous(
+    a_fp8: Tuple[torch.Tensor, torch.Tensor],
+    b_fp8: Tuple[torch.Tensor, torch.Tensor],
+    d: torch.Tensor,
+    m_indices: torch.Tensor,
+    recipe: Optional[Tuple[int, int, int]] = None,
+    compiled_dims: str = "nk",
+) -> None:
+    # Compiled dims can be upper cases
+    compiled_dims = compiled_dims.lower()
+
+    major_a = get_major_type_ab(a_fp8[0])
+    major_b = get_major_type_ab(b_fp8[0])
+
+    a, sfa = a_fp8
+    b, sfb = b_fp8
+    m, k = a.shape
+    num_groups, n, k_ = b.shape

     # Do nothing if the problem is empty
     if m == 0:
```
```diff
@@ -1423,6 +1464,72 @@ def m_grouped_fp8_gemm_nt_contiguous(
     impl(a, sfa, b, sfb, d, m_indices)


+@supported_compute_capability([100, 103])
```
**Contributor:** Same here, the compute capabilities will be ignored.
```diff
+def _check_m_grouped_fp8_gemm_nt_masked_problem_size(
+    a_fp8: Tuple[torch.Tensor, torch.Tensor],
+    b_fp8: Tuple[torch.Tensor, torch.Tensor],
+    d: torch.Tensor,
+    masked_m: torch.Tensor,
+    expected_m: int,
+    recipe: Optional[Tuple[int, int, int]] = None,
+    compiled_dims: str = "nk",
+) -> bool:
+    major_a = get_major_type_ab(a_fp8[0])
+    major_b = get_major_type_ab(b_fp8[0])
+    if major_a != MajorTypeAB.KMajor:
+        raise ValueError(f"major_a must be KMajor, but got {major_a}")
+    if major_b != MajorTypeAB.KMajor:
+        raise ValueError(f"major_b must be KMajor, but got {major_b}")
+
+    if not masked_m.is_contiguous():
+        raise ValueError(
+            f"masked_m must be contiguous, but got {masked_m.is_contiguous()}"
+        )
+
+    a, sfa = a_fp8
+    b, sfb = b_fp8
+    num_groups, m, k = a.shape
+    num_groups_, n, k_ = b.shape
+    num_groups__, m_, n_ = d.shape
+    num_groups___ = masked_m.numel()
+
+    # Type and shape checks
+    if (
+        num_groups != num_groups_
+        or num_groups != num_groups__
+        or num_groups != num_groups___
+    ):
+        raise ValueError(
+            f"num_groups mismatch. num_groups = {num_groups}, num_groups_ = {num_groups_}, num_groups__ = {num_groups__}, num_groups___ = {num_groups___}"
+        )
+    if m != m_ or n != n_ or k != k_:
+        raise ValueError(
+            f"m, n, k mismatch. m = {m}, m_ = {m_}, n = {n}, n_ = {n_}, k = {k}, k_ = {k_}"
+        )
+    if expected_m <= 0 or m <= 0 or n <= 0 or k <= 0 or num_groups <= 0:
+        raise ValueError(
+            f"expected_m, m, n, k, num_groups must be greater than 0, but got expected_m = {expected_m}, m = {m}, n = {n}, k = {k}, num_groups = {num_groups}"
+        )
+    if a.dtype != torch.float8_e4m3fn:
+        raise ValueError(f"a must be float8_e4m3fn, but got {a.dtype}")
+    if b.dtype != torch.float8_e4m3fn:
+        raise ValueError(f"b must be float8_e4m3fn, but got {b.dtype}")
+    if d.dtype != torch.bfloat16:
+        raise ValueError(f"d must be bfloat16, but got {d.dtype}")
+    if masked_m.dtype != torch.int32:
+        raise ValueError(f"masked_m must be int32, but got {masked_m.dtype}")
+
+    # D must be N-major
+    if get_major_type_cd(d) != MajorTypeCD.NMajor:
+        raise ValueError(f"d must be N-major, but got {get_major_type_cd(d)}")
+
+    return True
+
+
+@backend_requirement(
+    {},
+    common_check=_check_m_grouped_fp8_gemm_nt_masked_problem_size,
+)
 def m_grouped_fp8_gemm_nt_masked(
     a_fp8: Tuple[torch.Tensor, torch.Tensor],
     b_fp8: Tuple[torch.Tensor, torch.Tensor],
```
```diff
@@ -1445,20 +1552,6 @@ def m_grouped_fp8_gemm_nt_masked(
     b, sfb = b_fp8
     num_groups, m, k = a.shape
     num_groups_, n, k_ = b.shape
-    num_groups__, m_, n_ = d.shape
-    num_groups___ = masked_m.numel()
-
-    # Type and shape checks
-    assert num_groups == num_groups_ == num_groups__ == num_groups___
-    assert m == m_ and n == n_ and k == k_
-    assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
-    assert a.dtype == torch.float8_e4m3fn
-    assert b.dtype == torch.float8_e4m3fn
-    assert d.dtype == torch.bfloat16
-    assert masked_m.dtype == torch.int32
-
-    # D must be N-major
-    assert get_major_type_cd(d) == MajorTypeCD.NMajor

     # Transform SFA and SFB into compute-required layout
     recipe = get_default_recipe(sfa.dtype, sfb.dtype) if recipe is None else recipe
```
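For context on the pattern this PR introduces: the two decorators come from `flashinfer.utils` (see the widened import at the top of the diff), and their real implementations are not shown here. A minimal illustrative mock of how they are presumably composed, under the assumption that `backend_requirement` simply runs `common_check` before the wrapped entry point, is:

```python
import functools
from typing import Callable, Dict, List, Optional


def supported_compute_capability(ccs: List[int]):
    # Illustrative mock: record the supported compute capabilities on the
    # decorated check function (e.g. [100, 103] for SM 10.0 / 10.3).
    def decorator(fn: Callable) -> Callable:
        fn._supported_ccs = ccs  # hypothetical attribute name
        return fn
    return decorator


def backend_requirement(
    backend_checks: Dict[str, Callable],
    common_check: Optional[Callable] = None,
):
    # Illustrative mock: validate the call with common_check (which raises
    # ValueError on bad shapes/dtypes) before invoking the real function.
    def decorator(fn: Callable) -> Callable:
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if common_check is not None:
                common_check(*args, **kwargs)
            return fn(*args, **kwargs)
        return wrapper
    return decorator
```

With mocks like these, calling `m_grouped_fp8_gemm_nt_contiguous(...)` would first run `_check_group_deepgemm_fp8_nt_contiguous_problem_size(...)` and surface any `ValueError` before the kernel path is taken; as the review below points out, nothing in this composition consults the compute-capability list.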
Ah, the compute capability will not actually be checked on the `common_check` function. We would either need to:

My preference would go to option 2.
I think 2 also makes the most sense. In a lot of the APIs there is also no `backend` arg to be passed in, so we can't only check `@supported_compute_capability` there. I can change this in a separate PR.
A separate PR is fine, since we wouldn't cause a regression (we didn't have CC checks before the current PR anyway).
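To make the concern concrete: if the decorators behave roughly like the mock sketched after the diff, `@supported_compute_capability` merely annotates the check function, and `backend_requirement` calls it only as a shape/dtype validator, so the annotation is never compared against the device's actual compute capability. A hedged sketch of what option 2 might look like, assuming a hypothetical `_device_cc` helper and the `_supported_ccs` attribute from the earlier mock:

```python
import torch


def _device_cc() -> int:
    # Hypothetical helper: encode (major, minor) as major * 10 + minor,
    # e.g. SM 10.0 -> 100, SM 10.3 -> 103.
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor


def backend_requirement(backend_checks, common_check=None):
    # Sketch of "option 2": backend_requirement itself enforces the CC list
    # attached by @supported_compute_capability before running any checks.
    def decorator(fn):
        def wrapper(*args, **kwargs):
            if common_check is not None:
                ccs = getattr(common_check, "_supported_ccs", None)
                if ccs is not None and _device_cc() not in ccs:
                    raise ValueError(
                        f"Compute capability {_device_cc()} not in supported list {ccs}"
                    )
                common_check(*args, **kwargs)
            return fn(*args, **kwargs)
        return wrapper
    return decorator
```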