Changes from 3 commits

131 commits
604cfd4
bmm_fp8
jimmyzho Oct 16, 2025
499dcc5
bmm_fp8
jimmyzho Oct 16, 2025
ad39f67
gemm.py refactor
jimmyzho Oct 29, 2025
ebb610c
unittest: Add head dim 256 test cases and mark as xfail (#1999)
bkryu Oct 29, 2025
bb6b620
feat: autotune tile_tokens_dim in trtllm-gen MOE (#1980)
jiahanc Oct 29, 2025
6a962ef
Fix trtllm-gen attention illegal memory access (#2002)
Tom-Zheng Oct 29, 2025
3c07921
release: Bump version for v0.5.0rc1 release; (#2008)
bkryu Oct 30, 2025
b9287c9
bugfix: fix regex in update wheel index script (#2009)
yzh119 Oct 30, 2025
a5ff033
fix: Enable SM121 for mm_fp4 (#2012)
bkryu Oct 30, 2025
de4c701
fix: ensure SM120/121 SFA/SFB contiguity (#1963)
yongwww Oct 31, 2025
1181c5d
More realistic bench for POD Attn (#2013)
Edenzzzz Oct 31, 2025
f9cd034
Feature: Support non-gated activation in cutlass fused MoE nvfp4 (#2011)
omera-nv Oct 31, 2025
5854494
feat: add xqa backend and completes NHD/HND coverage for trtllm-gen/x…
qsang-nv Nov 2, 2025
da01b1b
test: Enable xfailed trtllm decode long seqlen tests and update micro…
bkryu Nov 2, 2025
1e75bff
Updated decorator to support unspecified default (#2026)
nvmbreughe Nov 3, 2025
2d68a6b
release: Bump version for v0.5.1 release (#2031)
bkryu Nov 4, 2025
d528f0c
ci: Update cudnn version requirements in CI container (#2039)
bkryu Nov 4, 2025
f2cc526
test: Mark test_fp8_prefill.py as xfail on SM90 (#2038)
bkryu Nov 4, 2025
e1c1e2a
Update Docker CI tags to 20251104-d528f0c (#2041)
flashinfer-bot Nov 5, 2025
9bc5bd5
bugfix: fix failed unittest `test_green_ctx` and `test_jit_example` o…
yzh119 Nov 5, 2025
2580610
perf: Speed up fp4 quantization for small batch with swizzling for cu…
bkryu Nov 5, 2025
579012b
Support cc common check decorator for empty backends (#2015)
jimmyzho Nov 5, 2025
6d19a75
use scalar for kv_scale in xqa (#2033)
qsang-nv Nov 5, 2025
9721ff7
fix: support both pip and uv pip for finding flashinfer-python packag…
djmmoss Nov 5, 2025
747b4e2
test: Fix test_sampling.py on Spark (#2042)
bkryu Nov 6, 2025
adb0e89
Fix dtype of output scales from mnnvl_moe_alltoallv_prepare_without_a…
trevor-m Nov 6, 2025
b211926
Update trtllm-gen fused moe routing kernel and add more kernels (#1955)
jiahanc Nov 6, 2025
26d587a
chore: Update CODEOWNERS (#1984)
flashinfer-bot Nov 6, 2025
aacc8df
Add support for topkPacked input in block-level renormalize (#2051)
ChristinaZ Nov 6, 2025
f25929f
test: Skip test_fp8_quantize.py on Hopper (#2052)
bkryu Nov 6, 2025
55ea787
[BUG] Fix trtllm-gen fp4 moe renormalize routing (#2049)
IwakuraRein Nov 6, 2025
63cf562
release: Bump version for v0.5.2 release (#2057)
bkryu Nov 7, 2025
adcc5dd
perf: improve sampling/mask/softmax performance (part 1/2) (#2044)
yzh119 Nov 7, 2025
f566d49
misc: Add XQA decode to microbenchmark for sm90 and sm120 (#2055)
bkryu Nov 7, 2025
36d2463
test: Skip unsupported SM Archs for newly added trtllm MoE test (#2060)
bkryu Nov 7, 2025
3cb8f9a
feat: suitable_auto_backends to prune auto backends, bmm_fp8 refactor…
jimmyzho Nov 7, 2025
20435b4
update trtllm cutlass moe (#2020)
nv-yunzheq Nov 7, 2025
f588d96
perf: Optimize helper max/minmax function in sampling.cuh (#2058)
bkryu Nov 7, 2025
c8f2b03
[DSV3] Optimized Router Gemm (#2019)
nvmbreughe Nov 7, 2025
e450c7d
Fix moe fp8 failure for sm121 (#2061)
yongwww Nov 7, 2025
ba011d1
perf: TRT-LLM MoE Block-FP8 activation optimization (#2063)
nekorobov Nov 8, 2025
74281ed
[feat] Refactor trtllmgen MOE and add Bf16 trtllmgen moe (#2014)
jiahanc Nov 8, 2025
d56748f
Fix: several bugs/issues with trtllm-gen attention kernels. (#2062)
PerkzZheng Nov 9, 2025
8d7d0bc
refactor: remove MetaInfoHash class (#2064)
yzh119 Nov 9, 2025
f5a06a4
chore: Update CODEOWNERS (#2067)
flashinfer-bot Nov 10, 2025
d42fb90
feat: add xqa mla backend (#2053)
qsang-nv Nov 10, 2025
fbdb439
Enable renormalize(naive) routing for fp8 per-tensor (#2030)
IwakuraRein Nov 11, 2025
11177e8
unittest: improve the efficiency of xqa unittests (#2075)
yzh119 Nov 11, 2025
eccbdde
minor: canonicalize TFLOPS calculation (#2069)
Edenzzzz Nov 12, 2025
96e73b8
fix: fix test_trtllm_gen_attention when max_seq_len < page_size (#2076)
dongjiyingdjy Nov 12, 2025
53a6da4
enable xqa fp8 output (#2081)
qsang-nv Nov 12, 2025
abf6a14
chore: update requires-python in pyproject.toml (#2080)
raayandhar Nov 12, 2025
6765cad
[Test] Optimize test_trtllm_gen_fused_moe.py (#2072)
jiahanc Nov 12, 2025
b433fc7
test: Change incorrect inputs in test_hopper.py (#2083)
bkryu Nov 13, 2025
54101e9
[NVIDIA] Thor & Spark Support (#2028)
johnnynunez Nov 13, 2025
9a79b78
[API change] deprecate tile_token_dim in trtllm_moe (#2086)
jiahanc Nov 14, 2025
636a3ab
[Feature] Support batch prefill for POD Attention (#2079)
AKKamath Nov 14, 2025
37434ed
feat: patch sm103 for 3xfp4 moe generation (#2082)
aleozlx Nov 14, 2025
ba8f3ed
MNNVL All Reduce for large number of tokens (#2074)
nvmbreughe Nov 14, 2025
cce4952
perf: TRT-LLM Gen finalize kernel optimization (#2092)
nekorobov Nov 14, 2025
4ddf71d
refactor: update dpsk fused_moe test [1] (#2088)
yyihuang Nov 16, 2025
d42b71f
chore: update thor cuda arch (from 110f to 110a) (#2096)
yzh119 Nov 16, 2025
4aed50c
perf: enable pdl for cutlass fp4 gemm (#2095)
yzh119 Nov 16, 2025
0a36050
chore: Update CODEOWNERS (#2098)
flashinfer-bot Nov 18, 2025
3b07247
feat: Add flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache (fu…
kahyunnam Nov 18, 2025
a9f71bd
[API change] Allow using torch.Tensor for scales for trtllm-gen atten…
IwakuraRein Nov 18, 2025
875403e
refactor: update dpsk fused_moe test [2] (#2097)
yyihuang Nov 18, 2025
1c4b522
hotfix: rename moe/test_utils.py to moe/utils.py (#2106)
yzh119 Nov 18, 2025
219592b
[DSR1] Added MLA test (#2100)
nvmbreughe Nov 19, 2025
b9964cc
test: Enable testing for trtllm-gen decode bs1 (#2103)
bkryu Nov 19, 2025
3a23405
[DSV3] Optimized routing kernels dsv3 (#2099)
nv-yunzheq Nov 19, 2025
0753095
feature: make the LSE returned by MLA support base 2 or e #2113 (#2114)
staugust Nov 20, 2025
76eea79
update xqa license (#2117)
qsang-nv Nov 20, 2025
af25b45
add tensor scale input for xqa (#2110)
qsang-nv Nov 20, 2025
049e8db
hotfix: add 9.0a to README and installation doc (#2112)
yzh119 Nov 20, 2025
2628beb
ci/cd: add nvidia-ml-py to requirments of build-system of flashinfer-…
yzh119 Nov 20, 2025
0aee7af
feat: Add backend='auto' to mm_fp4 and enable autotune for backend='c…
bkryu Nov 21, 2025
7128c7b
fix: Fix bench_mm_fp8.py (#2129)
bkryu Nov 21, 2025
5acb57b
feat: Enable API Logging for Better Debugging POC (#2108)
bkryu Nov 22, 2025
5e11004
fix: add a check for int32 indices in sampling.py (#2127)
raayandhar Nov 22, 2025
84df81e
update autotuner input tensor random range (#2116)
jiahanc Nov 22, 2025
2439a41
enable xqa speculative decoding (#2105)
qsang-nv Nov 22, 2025
d56be0d
Add custom communicator for trtllm_mnnvl_ar (#2056)
wenscarl Nov 22, 2025
cf2df82
fix: DeepSeek activation uninitialized data (#2128)
nekorobov Nov 22, 2025
9f13e83
chore: Update CODEOWNERS (#2135)
flashinfer-bot Nov 24, 2025
ecd4ef1
bugfix: fix unittest error introduced in #2056 (#2136)
yzh119 Nov 24, 2025
efd8554
fix flaky xqa test (#2126)
qsang-nv Nov 25, 2025
fd5273c
fix: some bugs of headDim 256 trtllm-gen fmha kernels. (#2137)
PerkzZheng Nov 25, 2025
aeeccac
fix(trtllm): reset negative strideBatch to 0 for ragged KV layout to …
YAMY1234 Nov 25, 2025
1940b28
feat: add trtllm-gen per-tensor sparseMla kernels. (#2138)
PerkzZheng Nov 25, 2025
d0d99d2
Use global TuningConfig, to fix memory leak caused by AutoTuner LRU c…
juju812 Nov 25, 2025
18004a8
feat: add seed offset args to sampler to allow cuda graph support (#2…
ksukrit Nov 26, 2025
df5c2e4
ci: Reduce test time by moving compilation off-line (#2089)
kahyunnam Nov 28, 2025
b14408b
feat: TRTLLM FMHAv2 backend for ctx attention (#2142)
jimmyzho Nov 28, 2025
dc37789
refactor: pass hopper deepgemm include directory through python (#2090)
yzh119 Dec 1, 2025
e59226b
bugfix: add driver support to CUPTI benchmark function, issue #2145 (…
nv-yunzheq Dec 2, 2025
23ff744
Bump tvm ffi version to 0.1.4 (#2155)
cyx-6 Dec 2, 2025
89e1adb
Update Docker CI tags to 20251202-23ff744 (#2158)
flashinfer-bot Dec 2, 2025
b56bb4f
misc: Label APIs for Logging (#2153)
bkryu Dec 2, 2025
4efb7bb
Update nvidia-cutlass-dsl version to 4.3.1 (#2161)
aleozlx Dec 3, 2025
685db69
chore: Update CODEOWNERS (#2152)
flashinfer-bot Dec 3, 2025
9ac59e5
feat: C++ side tensor validation (#2160)
raayandhar Dec 3, 2025
442dec9
Update Docker CI tags to 20251203-4efb7bb (#2164)
flashinfer-bot Dec 3, 2025
1e15fed
ci: Install CUDA version specified torch first during container build…
bkryu Dec 3, 2025
890bb46
fix xqa mha_sm90.cu (#2157)
qsang-nv Dec 4, 2025
cf6962a
Update Docker CI tags to 20251203-1e15fed (#2172)
flashinfer-bot Dec 4, 2025
40bc6e1
enable sm103 moe dsl backend (#2149)
aleozlx Dec 4, 2025
cdc5fb7
ci: Use stable Torch Release for cu130 (#2174)
bkryu Dec 4, 2025
eec483b
tiny upd `mm_fp4` docstring (#2177)
b8zhong Dec 5, 2025
cc50469
fix: compile flags for trtllm fmha_v2 (#2175)
jimmyzho Dec 5, 2025
54c1678
Fix/dsl smem query (#2178)
aleozlx Dec 5, 2025
4db4ac0
Update Docker CI tags to 20251204-cdc5fb7 (#2176)
flashinfer-bot Dec 5, 2025
6930085
feat: MxInt4 x Bf16 TRT-LLM Gen MoE support (#2159)
nekorobov Dec 5, 2025
1733727
refactor: Move mla code from decode.py to mla.py and add to documenta…
bkryu Dec 5, 2025
b972005
Fix gemm allreduce two shot (#2171)
aleozlx Dec 5, 2025
88328e8
Update Docker CI tags to 20251205-54c1678 (#2179)
flashinfer-bot Dec 5, 2025
ad3f26b
Rename noauxtc to fused_topk_deepseek (#2181)
nv-yunzheq Dec 6, 2025
09872a1
refactor: update fa3 codebase and fix hopper unittest [part 1] (#2111)
yzh119 Dec 6, 2025
70bc2b5
Add data type check for deepseek fp4 moe (#2165)
samuellees Dec 6, 2025
6dfc1ba
benchmark: Make use_cupti the default in microbenchmarks. (#2180)
bkryu Dec 6, 2025
185d63a
ci: Specify MPI implementation to mpich (#2182)
bkryu Dec 6, 2025
5fe01a2
Update Docker CI tags to 20251206-185d63a (#2184)
flashinfer-bot Dec 7, 2025
abcd8e0
test: Skip sm90 test in test_jit_warmup.py if not on sm90 (#2189)
bkryu Dec 9, 2025
6bb01d1
ci: Update sm12X minimum cuda capability to 12.9 in aot.py (#2188)
bkryu Dec 9, 2025
9f1cb89
Super tiny fix version (#2199)
fzyzcjy Dec 11, 2025
dc0ade7
docs: Document CUDA version support in README and installation page (…
bkryu Dec 11, 2025
3a29167
docs: Fix inaccurate API docstrings for attention prefill (#2196)
bkryu Dec 11, 2025
fae79ab
feat: unit-test and api change, w4a8 grouped-gemm fused MoE for SM90 …
jimmyzho Dec 11, 2025
f2b9cc0
Permute page table in benchmarking (#2194)
jhjpark Dec 11, 2025
8f4e806
bugfix: Fix for moe on sm110 (#2190)
jhalabi-nv Dec 12, 2025
25fe3eb
Merge branch 'bknd' of github.com:jimmyzho/flashinfer into bknd
jimmyzho Dec 12, 2025
155 changes: 124 additions & 31 deletions flashinfer/deep_gemm.py
@@ -45,7 +45,12 @@
from .cuda_utils import checkCudaErrors
from .jit.cubin_loader import get_cubin
from .jit.env import FLASHINFER_CUBIN_DIR
from .utils import ceil_div, round_up
from .utils import (
ceil_div,
round_up,
supported_compute_capability,
backend_requirement,
)


class GemmType(enum.Enum):
@@ -1358,24 +1363,27 @@ def m_grouped_fp8_gemm_nt_masked_sm10x(
runtime(**all_kwargs)


def m_grouped_fp8_gemm_nt_contiguous(
@supported_compute_capability([100, 103])
nvmbreughe (Contributor), Oct 29, 2025:
Ah, the compute capability will not actually be checked on the common_check function. We would need to either:

  1. Redesign the decorator so that we can just add @supported_compute_capability to m_grouped_fp8_gemm_nt_contiguous directly; OR
  2. Change the decorator's implementations of wrapper and is_compute_capability_supported so that it also checks this on the common_check; OR
  3. Add a backend parameter to m_grouped_fp8_gemm_nt_contiguous (though this is deepgemm, and I don't think we want to call this a backend for now).

My preference would go to option 2.

jimmyzho (PR author):
I think 2 also makes the most sense. A lot of the APIs also take no 'backend' arg, so we can't check this only in @supported_compute_capability. I can change this in a separate PR.

Contributor:
A separate PR is fine, since we wouldn't cause a regression (we didn't have CC checks before the current PR anyway).
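
For reference, a minimal sketch of what "option 2" could look like: the backend_requirement wrapper reads a compute-capability list that @supported_compute_capability attaches to the check function and verifies it before running the check. The decorator names mirror flashinfer.utils, but the bodies and the _supported_ccs attribute are hypothetical illustrations of the proposed behavior, not the library's actual implementation.

```python
# Hypothetical sketch only; not the real flashinfer.utils decorators.
import functools

import torch


def supported_compute_capability(ccs):
    """Attach the supported compute-capability list to the decorated check."""
    def decorator(fn):
        fn._supported_ccs = set(ccs)  # hypothetical attribute used by the wrapper below
        return fn
    return decorator


def backend_requirement(backend_checks, common_check=None):
    """Run common_check (and its CC annotation, if any) before the wrapped API."""
    def decorator(api_fn):
        @functools.wraps(api_fn)
        def wrapper(*args, **kwargs):
            if common_check is not None:
                ccs = getattr(common_check, "_supported_ccs", None)
                if ccs is not None:
                    major, minor = torch.cuda.get_device_capability()
                    cc = major * 10 + minor
                    if cc not in ccs:
                        raise RuntimeError(
                            f"Compute capability {cc} is not supported; "
                            f"expected one of {sorted(ccs)}"
                        )
                common_check(*args, **kwargs)  # raises ValueError on invalid inputs
            return api_fn(*args, **kwargs)
        return wrapper
    return decorator
```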

def _check_group_deepgemm_fp8_nt_contiguous_problem_size(
a_fp8: Tuple[torch.Tensor, torch.Tensor],
b_fp8: Tuple[torch.Tensor, torch.Tensor],
d: torch.Tensor,
m_indices: torch.Tensor,
recipe: Optional[Tuple[int, int, int]] = None,
compiled_dims: str = "nk",
) -> None:
# Compiled dims can be upper cases
compiled_dims = compiled_dims.lower()

) -> bool:
# NOTES: shape must be `[M, K] @ [G, N, K].mT`
major_a = get_major_type_ab(a_fp8[0])
major_b = get_major_type_ab(b_fp8[0])
assert major_a == MajorTypeAB.KMajor
if must_be_k_major():
assert major_b == MajorTypeAB.KMajor
assert m_indices.is_contiguous()
if major_a != MajorTypeAB.KMajor:
raise ValueError(f"major_a must be KMajor, but got {major_a}")
if must_be_k_major() and (major_b != MajorTypeAB.KMajor):
raise ValueError(f"major_b must be KMajor, but got {major_b}")

if not m_indices.is_contiguous():
raise ValueError(
f"m_indices must be contiguous, but got {m_indices.is_contiguous()}"
)

a, sfa = a_fp8
b, sfb = b_fp8
@@ -1385,15 +1393,48 @@ def m_grouped_fp8_gemm_nt_contiguous(
m__ = m_indices.numel()

Contributor (severity: high):

The check for positive dimensions n, k, and num_groups is missing. The original code had assert n > 0 and k > 0 and num_groups > 0. This check is important to prevent unexpected behavior with empty or invalid dimensions.

Suggested change:
    if n <= 0 or k <= 0 or num_groups <= 0:
        raise ValueError(f"n, k, and num_groups must be positive, but got n={n}, k={k}, num_groups={num_groups}")

# Type and shape checks
assert m == m_ == m__ and n == n_ and k == k_
assert n > 0 and k > 0 and num_groups > 0
assert a.dtype == torch.float8_e4m3fn
assert b.dtype == torch.float8_e4m3fn
assert d.dtype == torch.bfloat16
assert m_indices.dtype == torch.int32
if m != m_ or k != k_ or n != n_ or m__ != m_ or num_groups != m__:
raise ValueError(
f"Shape mismatch. m = {m}, m_ = {m_}, k = {k}, k_ = {k_}, n = {n}, n_ = {n_}, m__ = {m__}"
)
if a.dtype != torch.float8_e4m3fn:
raise ValueError(f"a must be float8_e4m3fn, but got {a.dtype}")
if b.dtype != torch.float8_e4m3fn:
raise ValueError(f"b must be float8_e4m3fn, but got {b.dtype}")
if d.dtype != torch.bfloat16:
raise ValueError(f"d must be bfloat16, but got {d.dtype}")
if m_indices.dtype != torch.int32:
raise ValueError(f"m_indices must be int32, but got {m_indices.dtype}")

# D must be N-major
assert get_major_type_cd(d) == MajorTypeCD.NMajor
if get_major_type_cd(d) != MajorTypeCD.NMajor:
raise ValueError(f"d must be N-major, but got {get_major_type_cd(d)}")

return True


@backend_requirement(
{},
common_check=_check_group_deepgemm_fp8_nt_contiguous_problem_size,
)
def m_grouped_fp8_gemm_nt_contiguous(
a_fp8: Tuple[torch.Tensor, torch.Tensor],
b_fp8: Tuple[torch.Tensor, torch.Tensor],
d: torch.Tensor,
m_indices: torch.Tensor,
recipe: Optional[Tuple[int, int, int]] = None,
compiled_dims: str = "nk",
) -> None:
# Compiled dims can be upper cases
compiled_dims = compiled_dims.lower()

major_a = get_major_type_ab(a_fp8[0])
major_b = get_major_type_ab(b_fp8[0])

a, sfa = a_fp8
b, sfb = b_fp8
m, k = a.shape
num_groups, n, k_ = b.shape

# Do nothing if the problem is empty
if m == 0:
@@ -1423,6 +1464,72 @@ def m_grouped_fp8_gemm_nt_contiguous(
impl(a, sfa, b, sfb, d, m_indices)
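
For orientation, a shape/dtype sketch for the refactored contiguous grouped GEMM, following the constraints enforced by _check_group_deepgemm_fp8_nt_contiguous_problem_size above ([M, K] @ [G, N, K].mT, int32 m_indices, bf16 N-major output). The scale-factor shapes assume the default (1, 128, 128) recipe and are an illustrative assumption; get_default_recipe is authoritative.

```python
# Illustrative shapes only; the sfa/sfb layout below assumes the default
# (1, 128, 128) recipe and is not taken from this diff.
import torch

M, G, N, K = 512, 4, 4096, 7168
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)       # K-major activations
b = torch.randn(G, N, K, device="cuda").to(torch.float8_e4m3fn)    # K-major grouped weights
sfa = torch.ones(M, K // 128, device="cuda", dtype=torch.float32)  # assumed per-block scales
sfb = torch.ones(G, N // 128, K // 128, device="cuda", dtype=torch.float32)
d = torch.empty(M, N, device="cuda", dtype=torch.bfloat16)         # N-major output
m_indices = torch.zeros(M, device="cuda", dtype=torch.int32)       # group id per row of a

# With this refactor, invalid inputs raise ValueError from the common check
# instead of tripping bare asserts:
# m_grouped_fp8_gemm_nt_contiguous((a, sfa), (b, sfb), d, m_indices)
```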


@supported_compute_capability([100, 103])
Contributor:
Same here, the compute capabilities will be ignored.

def _check_m_grouped_fp8_gemm_nt_masked_problem_size(
a_fp8: Tuple[torch.Tensor, torch.Tensor],
b_fp8: Tuple[torch.Tensor, torch.Tensor],
d: torch.Tensor,
masked_m: torch.Tensor,
expected_m: int,
recipe: Optional[Tuple[int, int, int]] = None,
compiled_dims: str = "nk",
) -> bool:
major_a = get_major_type_ab(a_fp8[0])
major_b = get_major_type_ab(b_fp8[0])
if major_a != MajorTypeAB.KMajor:
raise ValueError(f"major_a must be KMajor, but got {major_a}")
if major_b != MajorTypeAB.KMajor:
raise ValueError(f"major_b must be KMajor, but got {major_b}")

if not masked_m.is_contiguous():
raise ValueError(
f"masked_m must be contiguous, but got {masked_m.is_contiguous()}"
)

a, sfa = a_fp8
b, sfb = b_fp8
num_groups, m, k = a.shape
num_groups_, n, k_ = b.shape
num_groups__, m_, n_ = d.shape
num_groups___ = masked_m.numel()

# Type and shape checks
if (
num_groups != num_groups_
or num_groups != num_groups__
or num_groups != num_groups___
):
raise ValueError(
f"num_groups mismatch. num_groups = {num_groups}, num_groups_ = {num_groups_}, num_groups__ = {num_groups__}, num_groups___ = {num_groups___}"
)
if m != m_ or n != n_ or k != k_:
raise ValueError(
f"m, n, k mismatch. m = {m}, m_ = {m_}, n = {n}, n_ = {n_}, k = {k}, k_ = {k_}"
)
if expected_m <= 0 or m <= 0 or n <= 0 or k <= 0 or num_groups <= 0:
raise ValueError(
f"expected_m, m, n, k, num_groups must be greater than 0, but got expected_m = {expected_m}, m = {m}, n = {n}, k = {k}, num_groups = {num_groups}"
)
if a.dtype != torch.float8_e4m3fn:
raise ValueError(f"a must be float8_e4m3fn, but got {a.dtype}")
if b.dtype != torch.float8_e4m3fn:
raise ValueError(f"b must be float8_e4m3fn, but got {b.dtype}")
if d.dtype != torch.bfloat16:
raise ValueError(f"d must be bfloat16, but got {d.dtype}")
if masked_m.dtype != torch.int32:
raise ValueError(f"masked_m must be int32, but got {masked_m.dtype}")

# D must be N-major
if get_major_type_cd(d) != MajorTypeCD.NMajor:
raise ValueError(f"d must be N-major, but got {get_major_type_cd(d)}")

return True


@backend_requirement(
{},
common_check=_check_m_grouped_fp8_gemm_nt_masked_problem_size,
)
def m_grouped_fp8_gemm_nt_masked(
a_fp8: Tuple[torch.Tensor, torch.Tensor],
b_fp8: Tuple[torch.Tensor, torch.Tensor],
@@ -1445,20 +1552,6 @@ def m_grouped_fp8_gemm_nt_masked(
b, sfb = b_fp8
num_groups, m, k = a.shape
num_groups_, n, k_ = b.shape
num_groups__, m_, n_ = d.shape
num_groups___ = masked_m.numel()

# Type and shape checks
assert num_groups == num_groups_ == num_groups__ == num_groups___
assert m == m_ and n == n_ and k == k_
assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
assert a.dtype == torch.float8_e4m3fn
assert b.dtype == torch.float8_e4m3fn
assert d.dtype == torch.bfloat16
assert masked_m.dtype == torch.int32

# D must be N-major
assert get_major_type_cd(d) == MajorTypeCD.NMajor

# Transform SFA and SFB into compute-required layout
recipe = get_default_recipe(sfa.dtype, sfb.dtype) if recipe is None else recipe
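
Similarly, a hedged shape/dtype sketch for the masked variant, following _check_m_grouped_fp8_gemm_nt_masked_problem_size above: a is [G, M, K], b is [G, N, K], d is [G, M, N], masked_m holds the valid row count per group, and expected_m must be positive. The scale-factor shapes are again an assumption for illustration, not taken from this diff.

```python
# Illustrative shapes only; sfa/sfb layouts assume the default block recipe.
import torch

G, M, N, K = 8, 256, 4096, 7168
a = torch.randn(G, M, K, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(G, N, K, device="cuda").to(torch.float8_e4m3fn)
sfa = torch.ones(G, M, K // 128, device="cuda", dtype=torch.float32)
sfb = torch.ones(G, N // 128, K // 128, device="cuda", dtype=torch.float32)
d = torch.empty(G, M, N, device="cuda", dtype=torch.bfloat16)      # N-major per group
masked_m = torch.full((G,), 64, device="cuda", dtype=torch.int32)  # valid rows per group
expected_m = 64                                                    # must be > 0

# m_grouped_fp8_gemm_nt_masked((a, sfa), (b, sfb), d, masked_m, expected_m)
```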