Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 127 additions & 72 deletions primus_turbo/flydsl/gemm/gemm_fp8_kernel.py

Large diffs are not rendered by default.

Empty file.
2,244 changes: 2,244 additions & 0 deletions primus_turbo/flydsl/grouped_gemm/gemm_fp8_grouped_kernel.py

Large diffs are not rendered by default.

243 changes: 210 additions & 33 deletions primus_turbo/flydsl/utils/gemm_helper.py

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions primus_turbo/pytorch/kernels/gemm/gemm_fp8_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,10 @@ def can_handle(
# StoreC clamp + the global SRD.)
k = a.shape[0] if trans_a else a.shape[1]
supported &= k >= 129
# No size cap: foldable operands (NT both, NN-A) fold their per-tile base into
# the i64 SRD; the traversal operands (NN-B k*n, TN k*m & k*n) that would wrap a
# 32-bit soffset past 2^32 fp8 are re-based per load in i64 by the wrapper (it
# auto-selects i64 at/above 2^32, keeping the cheaper int32 path below).
# per-tensor scalar scale (wrapper broadcasts to vector internally)
supported &= a_scale_inv.numel() == 1 and b_scale_inv.numel() == 1
return supported
Expand Down
138 changes: 138 additions & 0 deletions primus_turbo/pytorch/kernels/grouped_gemm/grouped_gemm_fp8_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
float8_e4m3,
float8_e5m2,
)
from primus_turbo.pytorch.core.utils import get_device_compute_capability
from primus_turbo.pytorch.kernels.grouped_gemm.grouped_gemm_utils import (
BaseGroupedGEMMKernelDispatcher,
BaseGroupedGEMMVariableKKernelDispatcher,
Expand Down Expand Up @@ -425,11 +426,74 @@ def execute(
)


class GroupedGEMMFP8FlyDSLBackend(KernelBackend):
    """FlyDSL fp8 grouped GEMM backend (gfx950, per-tensor / TENSORWISE only).

    M-grouped operator: forward (trans_b=True, NT) + dgrad (trans_b=False, NN).
    Uses the FlyDSL mfma_f32_16x16x128_f8f6f4 kernel (gfx950-only).
    """

    SUPPORTED_GRANULARITIES = {ScalingGranularity.TENSORWISE}
    SUPPORTED_DTYPES = set(_COMMON_SUPPORTED_DTYPES + _HYBRID_SUPPORTED_DTYPES)

    @staticmethod
    def can_handle(
        a: torch.Tensor,
        b: torch.Tensor,
        a_scales: torch.Tensor,
        b_scales: torch.Tensor,
        group_lens: torch.Tensor,
        group_offs: torch.Tensor,
        trans_a: bool,
        trans_b: bool,
        out_dtype: torch.dtype,
        granularity: ScalingGranularity,
        num_cu: int | None,
        **kwargs,
    ) -> bool:
        """Return whether this backend accepts the given grouped-GEMM problem.

        The problem is accepted only when all of the following hold:
        ``a`` is 2-D and ``b`` is 3-D; ``(a.dtype, b.dtype, out_dtype)`` is in
        ``SUPPORTED_DTYPES``; ``granularity`` is in ``SUPPORTED_GRANULARITIES``;
        ``trans_a`` is false; both scale tensors hold exactly one element; the
        device compute capability is at least ``(9, 5)``; and ``a.shape[1]``
        is at least 129.

        Checks are guard clauses, so an operand with an unexpected rank is
        rejected with ``False`` before ``a.shape[1]`` is ever indexed (a
        non-short-circuiting ``&=`` chain would raise ``IndexError`` there).
        ``group_lens``, ``group_offs``, ``trans_b`` and ``num_cu`` are not
        inspected.
        """
        backend = GroupedGEMMFP8FlyDSLBackend

        # Rank must be validated first: the K check below indexes a.shape[1].
        if a.dim() != 2 or b.dim() != 3:
            return False
        if (a.dtype, b.dtype, out_dtype) not in backend.SUPPORTED_DTYPES:
            return False
        if granularity not in backend.SUPPORTED_GRANULARITIES:
            return False
        if trans_a:
            return False
        # per-tensor scaling = single scalar each
        if a_scales.numel() != 1 or b_scales.numel() != 1:
            return False
        # gfx950 (CDNA4) only: kernel uses mfma_f32_16x16x128_f8f6f4.
        if not get_device_compute_capability() >= (9, 5):
            return False
        # K-loop needs ceil(K/128) >= 2, i.e. contraction K >= 129.
        return a.shape[1] >= 129

    @staticmethod
    def execute(
        a: torch.Tensor,
        b: torch.Tensor,
        a_scales: torch.Tensor,
        b_scales: torch.Tensor,
        group_lens: torch.Tensor,
        group_offs: torch.Tensor,
        trans_a: bool,
        trans_b: bool,
        out_dtype: torch.dtype,
        granularity: ScalingGranularity,
        num_cu: int | None,
        **kwargs,
    ):
        """Run the FlyDSL tensorwise grouped GEMM kernel and return its result.

        Only ``a``, ``b``, the two scale tensors, ``group_offs``, ``trans_b``,
        ``out_dtype`` and ``num_cu`` are forwarded to the kernel; the remaining
        arguments are accepted for interface compatibility and ignored here.
        """
        # Imported lazily so the FlyDSL kernel module is only loaded when this
        # backend is actually executed.
        from primus_turbo.flydsl.grouped_gemm.gemm_fp8_grouped_kernel import (
            grouped_gemm_fp8_tensorwise_flydsl_kernel,
        )

        return grouped_gemm_fp8_tensorwise_flydsl_kernel(
            a, b, a_scales, b_scales, group_offs, trans_b=trans_b, out_dtype=out_dtype, num_cu=num_cu
        )


class GroupedGEMMFP8KernelDispatcher(BaseGroupedGEMMKernelDispatcher):
_backends = {
BackendType.CK: BackendEntry(GroupedGEMMFP8CKBackend),
BackendType.HIPBLASLT: BackendEntry(GroupedGEMMFP8HipblasltBackend, autotune=False),
BackendType.TRITON: BackendEntry(GroupedGEMMFP8TritonBackend),
BackendType.FLYDSL: BackendEntry(GroupedGEMMFP8FlyDSLBackend),
Comment thread
kyle-256 marked this conversation as resolved.
Comment thread
kyle-256 marked this conversation as resolved.
}
_cache = TuneCache(1024)

Expand Down Expand Up @@ -580,11 +644,85 @@ def execute(
)


class GroupedGEMMFP8VariableKFlyDSLBackend(KernelBackend):
    """Variable-K fp8 grouped GEMM on FlyDSL (gfx950, per-tensor scaling only).

    Weight-gradient form: C[g] = lhs[offs[g]:offs[g+1]]^T @ rhs[offs[g]:offs[g+1]].
    The contraction length is m_g, which differs per group and is walked by a
    runtime scf.for K-loop. Backed by the FlyDSL mfma_f32_16x16x128_f8f6f4 TN
    kernel (gfx950-only).
    """

    SUPPORTED_GRANULARITIES = {ScalingGranularity.TENSORWISE}
    SUPPORTED_DTYPES = set(_COMMON_SUPPORTED_DTYPES + _HYBRID_SUPPORTED_DTYPES)

    @staticmethod
    def can_handle(
        a: torch.Tensor,
        b: torch.Tensor,
        a_scales: torch.Tensor,
        b_scales: torch.Tensor,
        group_lens: torch.Tensor,
        group_offs: torch.Tensor,
        trans_a: bool,
        trans_b: bool,
        trans_c: bool,
        out_dtype: torch.dtype,
        granularity: ScalingGranularity,
        num_cu: int | None,
        **kwargs,
    ) -> bool:
        """Report whether this backend accepts the given variable-K problem.

        Every requirement is evaluated (none is skipped), and the problem is
        accepted only if all of them hold.
        """
        backend = GroupedGEMMFP8VariableKFlyDSLBackend
        dtype_combo = (a.dtype, b.dtype, out_dtype)

        requirements = (
            # both operands are 2-D
            a.dim() == 2 and b.dim() == 2,
            dtype_combo in backend.SUPPORTED_DTYPES,
            granularity in backend.SUPPORTED_GRANULARITIES,
            # variable-K contract: contraction along the shared (rows) dim.
            trans_a and not trans_b,
            # per-tensor scaling = single scalar each
            a_scales.numel() == 1 and b_scales.numel() == 1,
            # gfx950 (CDNA4) only: kernel uses mfma_f32_16x16x128_f8f6f4.
            get_device_compute_capability() >= (9, 5),
        )
        return all(requirements)

    @staticmethod
    def execute(
        a: torch.Tensor,
        b: torch.Tensor,
        a_scales: torch.Tensor,
        b_scales: torch.Tensor,
        group_lens: torch.Tensor,
        group_offs: torch.Tensor,
        trans_a: bool,
        trans_b: bool,
        trans_c: bool,
        out_dtype: torch.dtype,
        granularity: ScalingGranularity,
        num_cu: int | None,
        **kwargs,
    ):
        """Launch the FlyDSL variable-K tensorwise kernel and return its result."""
        from primus_turbo.flydsl.grouped_gemm.gemm_fp8_grouped_kernel import (
            grouped_gemm_fp8_variable_k_tensorwise_flydsl_kernel,
        )

        # A transposed output is obtained by exchanging the operand roles (and
        # their scales), mirroring the Triton variable-K backend:
        # out[g] = lhs[g]^T @ rhs[g].
        operands = (b, a, b_scales, a_scales) if trans_c else (a, b, a_scales, b_scales)
        lhs, rhs, lhs_scales, rhs_scales = operands

        return grouped_gemm_fp8_variable_k_tensorwise_flydsl_kernel(
            lhs, rhs, lhs_scales, rhs_scales, group_offs, out_dtype=out_dtype, num_cu=num_cu
        )


class GroupedGEMMFP8VariableKKernelDispatcher(BaseGroupedGEMMVariableKKernelDispatcher):
_backends = {
BackendType.CK: BackendEntry(GroupedGEMMFP8VariableKCKBackend),
BackendType.HIPBLASLT: BackendEntry(GroupedGEMMFP8VariableKHipblasltBackend),
BackendType.TRITON: BackendEntry(GroupedGEMMFP8VariableKTritonBackend),
BackendType.FLYDSL: BackendEntry(GroupedGEMMFP8VariableKFlyDSLBackend),
Comment thread
kyle-256 marked this conversation as resolved.
}
_cache = TuneCache(1024)

Expand Down
18 changes: 15 additions & 3 deletions tests/pytorch/ops/test_grouped_gemm_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,9 +306,13 @@ def _run_once():
@pytest.mark.parametrize("format", FORMAT_VALUES)
@pytest.mark.parametrize("trans_b", TRANS_B_VALUES)
@pytest.mark.parametrize("balance", BALANCE_VALUES)
@pytest.mark.parametrize("backend", [BackendType.CK, BackendType.HIPBLASLT, BackendType.TRITON])
@pytest.mark.parametrize(
"backend", [BackendType.CK, BackendType.HIPBLASLT, BackendType.TRITON, BackendType.FLYDSL]
)
@pytest.mark.deterministic
def test_grouped_gemm_fp8_tensorwise_deterministic(B, M, NK, ori_dtype, format, trans_b, balance, backend):
if backend == BackendType.FLYDSL and get_device_compute_capability() < (9, 5):
pytest.skip("FlyDSL fp8 grouped GEMM is gfx950-only")
Comment on lines +314 to +315
N, K = NK
_run_grouped_gemm_fp8_deterministic_test(
B=B,
Expand Down Expand Up @@ -423,10 +427,14 @@ def test_grouped_gemm_fp8_mx_blockwise_deterministic(B, M, NK, ori_dtype, format
@pytest.mark.parametrize("format", FORMAT_VALUES + [Format.HYBRID])
@pytest.mark.parametrize("trans_b", TRANS_B_VALUES)
@pytest.mark.parametrize("balance", BALANCE_VALUES)
@pytest.mark.parametrize("backend", [None, BackendType.CK, BackendType.HIPBLASLT, BackendType.TRITON])
@pytest.mark.parametrize(
"backend", [None, BackendType.CK, BackendType.HIPBLASLT, BackendType.TRITON, BackendType.FLYDSL]
)
@pytest.mark.parametrize("auto_tune", [False, True])
def test_grouped_gemm_fp8_tensorwise(B, M, NK, ori_dtype, format, trans_b, balance, backend, auto_tune):

if backend == BackendType.FLYDSL and get_device_compute_capability() < (9, 5):
pytest.skip("FlyDSL fp8 grouped GEMM is gfx950-only")
Comment on lines +436 to +437
# TODO(xiaobochen-amd): On gfx942, the hipBLASLt path can hang/flake when M <= 512.
# This has been observed under pytest; root cause not yet identified. MI355 works normally.
# Skip also when auto_tune=True because the tuner may select hipBLASLt.
Expand Down Expand Up @@ -660,12 +668,16 @@ def _run_grouped_gemm_fp8_quantized_tensor_test(
@pytest.mark.parametrize("format", FORMAT_VALUES + [Format.HYBRID])
@pytest.mark.parametrize("trans_b", TRANS_B_VALUES)
@pytest.mark.parametrize("balance", BALANCE_VALUES)
@pytest.mark.parametrize("backend", [None, BackendType.CK, BackendType.HIPBLASLT, BackendType.TRITON])
@pytest.mark.parametrize(
"backend", [None, BackendType.CK, BackendType.HIPBLASLT, BackendType.TRITON, BackendType.FLYDSL]
)
@pytest.mark.parametrize("auto_tune", [False, True])
def test_grouped_gemm_fp8_tensorwise_quantized_tensor(
B, M, NK, ori_dtype, format, trans_b, balance, backend, auto_tune
):
"""TENSORWISE grouped_gemm with pre-quantized grouped/regular QuantizedTensor inputs."""
if backend == BackendType.FLYDSL and get_device_compute_capability() < (9, 5):
pytest.skip("FlyDSL fp8 grouped GEMM is gfx950-only")
Comment on lines +679 to +680
if backend == BackendType.TRITON and format == Format.HYBRID:
pytest.skip("TRITON backend not support HYBRID format currently")

Expand Down
Loading