Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 68 additions & 16 deletions atom/model_ops/fused_moe/mori_prepare_finalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def init_mori_op(
max_token_type_size: int,
low_latency: bool = False,
instance_id: int = 0,
scale_type_size: int = torch.float32.itemsize,
quant_type: str = "none",
) -> Any:
"""
Create a mori op instance.
Expand Down Expand Up @@ -76,7 +78,7 @@ def init_mori_op(
data_type=data_type,
hidden_dim=hidden_dim,
scale_dim=scale_dim,
scale_type_size=torch.float32.itemsize,
scale_type_size=scale_type_size,
max_token_type_size=max_token_type_size,
max_num_inp_token_per_rank=max_num_inp_token_per_rank,
num_experts_per_rank=num_local_experts,
Expand All @@ -86,6 +88,7 @@ def init_mori_op(
kernel_type=kernel_type,
gpu_per_node=gpu_per_node,
rdma_block_num=rdma_block_num,
quant_type=quant_type,
**({"num_qp_per_pe": 2} if low_latency else {}),
)
mori_op = mori.ops.EpDispatchCombineOp(mori_config)
Expand All @@ -108,11 +111,14 @@ def __init__(
max_tokens_per_rank: int,
num_dispatchers: int,
use_fp8_dispatch: bool = False,
dispatch_quant_dtype: torch.dtype | None = None,
quant_type=None,
quant_dtype: torch.dtype = None,
is_async: bool = False,
tbo_mori_ops: list | None = None,
low_latency: bool = False,
fp8_dispatch_decode_only: bool = False,
dispatch_quant_decode_only: bool = False,
):
if not MORI_AVAILABLE:
raise ImportError(
Expand All @@ -124,11 +130,22 @@ def __init__(
self._tbo_mori_ops = tbo_mori_ops # per-ubatch ops for TBO (IntraNode)
self.num_dispatchers_ = num_dispatchers
self.max_tokens_per_rank = max_tokens_per_rank
self.use_fp8_dispatch = use_fp8_dispatch
# dispatch quantization target. We keep `use_fp8_dispatch` for backward
# compatibility and map it to the generalized quantized-dispatch path.
if dispatch_quant_dtype is None and use_fp8_dispatch:
dispatch_quant_dtype = dtypes.fp8
self.dispatch_quant_dtype = dispatch_quant_dtype
self.quant_type = quant_type
self.quant_dtype = quant_dtype
self._is_async = is_async
self._low_latency = low_latency
# Phase gating for quantized dispatch (fp8/fp4). Keep the legacy
# fp8-only flag for compatibility.
self.dispatch_quant_decode_only = (
dispatch_quant_decode_only or fp8_dispatch_decode_only
)
self.use_fp8_dispatch = self.dispatch_quant_dtype == dtypes.fp8
self.fp8_dispatch_decode_only = self.dispatch_quant_decode_only

@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
Expand Down Expand Up @@ -160,6 +177,40 @@ def _get_dispatch_config(self):
return 128, 16
return 64, 4

def _dispatch_quant_dtype_now(self) -> torch.dtype | None:
    """Resolve the dispatch quantization dtype for the current forward pass.

    Returns:
        ``None`` when no dispatch dtype is configured, or when quantized
        dispatch is limited to decode and the forward context reports a
        prefill step; otherwise the configured ``dispatch_quant_dtype``.
    """
    configured = self.dispatch_quant_dtype
    if configured is None:
        return None
    # The forward context is only consulted when decode-only gating is on.
    suppressed_for_prefill = (
        self.dispatch_quant_decode_only
        and get_forward_context().context.is_prefill
    )
    return None if suppressed_for_prefill else configured

def _quantize_dispatch_input(
    self, a1: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Quantize the dispatch input according to the phase-aware dispatch dtype.

    Args:
        a1: Activation tensor about to be dispatched.

    Returns:
        ``(a1, None)`` unchanged when no quantized dispatch applies to the
        current forward; otherwise whatever the per_1x32 quantizer returns
        for the selected dtype (fp8 or fp4x2) with an e8m0 scale type.

    Raises:
        ValueError: If the resolved dispatch dtype is neither fp8 nor fp4x2.
    """
    dispatch_quant_dtype = self._dispatch_quant_dtype_now()
    if dispatch_quant_dtype is None:
        return a1, None

    from aiter import get_hip_quant

    mx_quantize = get_hip_quant(QuantType.per_1x32)
    if dispatch_quant_dtype == dtypes.fp8:
        # MXFP8 per_1x32 + e8m0 scale.
        packed_dtype = dtypes.fp8
    elif dispatch_quant_dtype == dtypes.fp4x2:
        # MXFP4 per_1x32 + e8m0 scale.
        packed_dtype = dtypes.fp4x2
    else:
        raise ValueError(f"Unsupported dispatch quant dtype: {dispatch_quant_dtype}")
    # Both MX formats share the same quantizer call; only the target dtype differs.
    return mx_quantize(
        a1,
        quant_dtype=packed_dtype,
        scale_type=dtypes.fp8_e8m0,
    )

# ---- Synchronous (non-TBO) path ----

def prepare(
Expand All @@ -186,12 +237,7 @@ def prepare(
assert (
not apply_router_weight_on_input
), "mori does not support apply_router_weight_on_input=True now."
scale = None
if self.use_fp8_dispatch:
from aiter import get_hip_quant

quant_func = get_hip_quant(quant_type)
a1, scale = quant_func(a1, quant_dtype=dtypes.fp8)
a1, scale = self._quantize_dispatch_input(a1)

block_num, warp_per_block = self._get_dispatch_config()

Expand Down Expand Up @@ -254,19 +300,25 @@ def prepare_async(
), "mori does not support apply_router_weight_on_input=True now."

scale = None
if self.use_fp8_dispatch:
from aiter import get_hip_quant

dispatch_quant_dtype = self._dispatch_quant_dtype_now()
if dispatch_quant_dtype is not None:
num_tokens = a1.shape[0]
if num_tokens > 0:
quant_func = get_hip_quant(QuantType.per_1x128)
a1, scale = quant_func(a1, quant_dtype=dtypes.fp8)
a1, scale = self._quantize_dispatch_input(a1)
else:
hidden_size = a1.shape[1] if a1.dim() > 1 else 0
a1 = torch.empty(a1.shape, dtype=dtypes.fp8, device=a1.device)
if dispatch_quant_dtype == dtypes.fp4x2:
packed_hidden_size = hidden_size // 2
a1 = torch.empty(
(0, packed_hidden_size),
dtype=dtypes.fp4x2,
device=a1.device,
)
else:
a1 = torch.empty(a1.shape, dtype=dispatch_quant_dtype, device=a1.device)
scale = torch.empty(
(0, hidden_size // 128),
dtype=torch.float32,
(0, hidden_size // 32),
dtype=dtypes.fp8_e8m0,
device=a1.device,
)

Expand Down
Loading