diff --git a/atom/model_ops/fused_moe/mori_prepare_finalize.py b/atom/model_ops/fused_moe/mori_prepare_finalize.py index a8eaf355d0..a0fc889a0b 100644 --- a/atom/model_ops/fused_moe/mori_prepare_finalize.py +++ b/atom/model_ops/fused_moe/mori_prepare_finalize.py @@ -40,6 +40,8 @@ def init_mori_op( max_token_type_size: int, low_latency: bool = False, instance_id: int = 0, + scale_type_size: int = torch.float32.itemsize, + quant_type: str = "none", ) -> Any: """ Create a mori op instance. @@ -76,7 +78,7 @@ def init_mori_op( data_type=data_type, hidden_dim=hidden_dim, scale_dim=scale_dim, - scale_type_size=torch.float32.itemsize, + scale_type_size=scale_type_size, max_token_type_size=max_token_type_size, max_num_inp_token_per_rank=max_num_inp_token_per_rank, num_experts_per_rank=num_local_experts, @@ -86,6 +88,7 @@ def init_mori_op( kernel_type=kernel_type, gpu_per_node=gpu_per_node, rdma_block_num=rdma_block_num, + quant_type=quant_type, **({"num_qp_per_pe": 2} if low_latency else {}), ) mori_op = mori.ops.EpDispatchCombineOp(mori_config) @@ -108,11 +111,14 @@ def __init__( max_tokens_per_rank: int, num_dispatchers: int, use_fp8_dispatch: bool = False, + dispatch_quant_dtype: torch.dtype | None = None, quant_type=None, quant_dtype: torch.dtype = None, is_async: bool = False, tbo_mori_ops: list | None = None, low_latency: bool = False, + fp8_dispatch_decode_only: bool = False, + dispatch_quant_decode_only: bool = False, ): if not MORI_AVAILABLE: raise ImportError( @@ -124,11 +130,22 @@ def __init__( self._tbo_mori_ops = tbo_mori_ops # per-ubatch ops for TBO (IntraNode) self.num_dispatchers_ = num_dispatchers self.max_tokens_per_rank = max_tokens_per_rank - self.use_fp8_dispatch = use_fp8_dispatch + # dispatch quantization target. We keep `use_fp8_dispatch` for backward + # compatibility and map it to the generalized quantized-dispatch path. + if dispatch_quant_dtype is None and use_fp8_dispatch: + dispatch_quant_dtype = dtypes.fp8 + self.dispatch_quant_dtype = dispatch_quant_dtype self.quant_type = quant_type self.quant_dtype = quant_dtype self._is_async = is_async self._low_latency = low_latency + # Phase gating for quantized dispatch (fp8/fp4). Keep the legacy + # fp8-only flag for compatibility. + self.dispatch_quant_decode_only = ( + dispatch_quant_decode_only or fp8_dispatch_decode_only + ) + self.use_fp8_dispatch = self.dispatch_quant_dtype == dtypes.fp8 + self.fp8_dispatch_decode_only = self.dispatch_quant_decode_only @property def activation_format(self) -> mk.FusedMoEActivationFormat: @@ -160,6 +177,40 @@ def _get_dispatch_config(self): return 128, 16 return 64, 4 + def _dispatch_quant_dtype_now(self) -> torch.dtype | None: + """Dispatch quant dtype decision for the current forward (phase-aware).""" + if self.dispatch_quant_dtype is None: + return None + if self.dispatch_quant_decode_only and get_forward_context().context.is_prefill: + return None + return self.dispatch_quant_dtype + + def _quantize_dispatch_input( + self, a1: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor | None]: + dispatch_quant_dtype = self._dispatch_quant_dtype_now() + if dispatch_quant_dtype is None: + return a1, None + + from aiter import get_hip_quant + + quant_func = get_hip_quant(QuantType.per_1x32) + if dispatch_quant_dtype == dtypes.fp8: + # MXFP8 per_1x32 + e8m0 scale. + return quant_func( + a1, + quant_dtype=dtypes.fp8, + scale_type=dtypes.fp8_e8m0, + ) + if dispatch_quant_dtype == dtypes.fp4x2: + # MXFP4 per_1x32 + e8m0 scale. + return quant_func( + a1, + quant_dtype=dtypes.fp4x2, + scale_type=dtypes.fp8_e8m0, + ) + raise ValueError(f"Unsupported dispatch quant dtype: {dispatch_quant_dtype}") + # ---- Synchronous (non-TBO) path ---- def prepare( @@ -186,12 +237,7 @@ def prepare( assert ( not apply_router_weight_on_input ), "mori does not support apply_router_weight_on_input=True now." - scale = None - if self.use_fp8_dispatch: - from aiter import get_hip_quant - - quant_func = get_hip_quant(quant_type) - a1, scale = quant_func(a1, quant_dtype=dtypes.fp8) + a1, scale = self._quantize_dispatch_input(a1) block_num, warp_per_block = self._get_dispatch_config() @@ -254,19 +300,25 @@ def prepare_async( ), "mori does not support apply_router_weight_on_input=True now." scale = None - if self.use_fp8_dispatch: - from aiter import get_hip_quant - + dispatch_quant_dtype = self._dispatch_quant_dtype_now() + if dispatch_quant_dtype is not None: num_tokens = a1.shape[0] if num_tokens > 0: - quant_func = get_hip_quant(QuantType.per_1x128) - a1, scale = quant_func(a1, quant_dtype=dtypes.fp8) + a1, scale = self._quantize_dispatch_input(a1) else: hidden_size = a1.shape[1] if a1.dim() > 1 else 0 - a1 = torch.empty(a1.shape, dtype=dtypes.fp8, device=a1.device) + if dispatch_quant_dtype == dtypes.fp4x2: + packed_hidden_size = hidden_size // 2 + a1 = torch.empty( + (0, packed_hidden_size), + dtype=dtypes.fp4x2, + device=a1.device, + ) + else: + a1 = torch.empty(a1.shape, dtype=dispatch_quant_dtype, device=a1.device) scale = torch.empty( - (0, hidden_size // 128), - dtype=torch.float32, + (0, hidden_size // 32), + dtype=dtypes.fp8_e8m0, device=a1.device, ) diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index bd82373aef..bb1fe5b04c 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -import logging import os from abc import abstractmethod from dataclasses import dataclass @@ -21,6 +20,7 @@ QuantizationConfig, get_current_atom_config, ) +from atom.quant_spec import LayerQuantConfig, should_skip_online_quant from atom.model_loader.weight_utils import set_weight_attrs from atom.model_ops.base_config import QuantizeMethodBase from atom.model_ops.fused_moe.config import ( @@ -28,7 +28,6 @@ FusedMoEConfig, FusedMoEQuantConfig, fp8_w8a8_moe_quant_config, - mxfp4_w4a8_moe_quant_config, mxfp4_w4a16_moe_quant_config, ) from atom.model_ops.fused_moe.modular_kernel import ( @@ -50,34 +49,14 @@ per_tensor_dequantize, shuffle_weights, ) -from atom.plugin.vllm.moe import FusedMoEDecoratorForPluginMode -from atom.quant_spec import LayerQuantConfig, should_skip_online_quant -from atom.quantization.quark.utils import weight_dequant_fp8 from atom.utils import envs from atom.utils.custom_register import direct_register_custom_op -from atom.utils.decorators import mark_trace from atom.utils.forward_context import get_forward_context +from atom.utils.decorators import mark_trace from torch import nn from transformers import PretrainedConfig - -logger = logging.getLogger("atom") - - -class MoEActivationQuant(Enum): - BF16 = "bf16" - FP8 = "fp8" - FP4 = "fp4" - - @staticmethod - def from_model_config(a_quant_dtype: str | None) -> "MoEActivationQuant": - if a_quant_dtype is None or a_quant_dtype == "": - return MoEActivationQuant.BF16 - prefix = a_quant_dtype.split("_")[0] - if prefix == "fp8": - return MoEActivationQuant.FP8 - if prefix in ("fp4", "uint8"): - return MoEActivationQuant.FP4 - return MoEActivationQuant.BF16 +from atom.plugin.vllm.moe import FusedMoEDecoratorForPluginMode +from atom.quantization.quark.utils import weight_dequant_fp8 class FusedMoeWeightScaleSupported(Enum): @@ -201,63 +180,51 @@ def naive_multicast( return buffer -def pad_for_all_gather(x: torch.Tensor) -> Tuple[torch.Tensor, int]: - """Zero-pad ``x`` along dim 0 up to the uniform all-gather batch size. - - Every DP rank must contribute the same number of rows to the uniform - all-gather, so a short batch is padded up to ``graph_bs`` (scaled by the - per-sequence query length when decoding with MTP > 1). - - The padding MUST be zeros, never uninitialized memory: padded rows are - all-gathered across DP ranks and fed straight into the aiter fused-MoE - expert GEMM, where garbage values leak into real tokens' outputs. - Bisection traced a ~0.7pp GSM8K drop at large batch to a bare - ``torch.empty`` pad here; explicitly zeroing the pad rows fixes it. - - Returns the (possibly padded) tensor and the original row count so the - caller can unpad after reduce-scatter. - """ +def pad_for_all_gather(x: torch.Tensor): ctx = get_forward_context() max_batch_size = ctx.context.graph_bs if not ctx.context.is_prefill and ctx.attn_metadata is not None: + # For MTP > 1 max_batch_size *= ctx.attn_metadata.max_seqlen_q + dim = 0 + original_batch_size = x.shape[dim] + padded_x = x + if original_batch_size < max_batch_size: + padding_size = max_batch_size - original_batch_size - original_batch_size = x.shape[0] - padding_size = max_batch_size - original_batch_size - if padding_size <= 0: - return x, original_batch_size + padding_shape = list(x.shape) + padding_shape[dim] = padding_size + + padding = torch.empty(padding_shape, dtype=x.dtype, device=x.device) + # padding.zero_() + padded_x = torch.cat([x, padding], dim=dim) - padding_shape = list(x.shape) - padding_shape[0] = max_batch_size - padded_x = torch.empty(padding_shape, device=x.device, dtype=x.dtype) - padded_x[:original_batch_size, :].copy_(x) - # padded_x[original_batch_size:, :].zero_() return padded_x, original_batch_size -def all_gather_with_padding( - x: torch.Tensor, use_cag: bool = True -) -> Tuple[torch.Tensor, int]: +def all_gather_with_padding(x: torch.Tensor): padded_x, original_batch_size = pad_for_all_gather(x) # use_custom=True routes through CA IPC (outplace_all_gather). Default # use_custom=False falls back to torch.distributed.all_gather_into_tensor # (NCCL), whose WorkNCCL end-event recorded inside CUDAGraph capture is # later queried by the watchdog thread -> hipErrorCapturedEvent crash. - gathered_hidden_states = get_dp_group().all_gather( - padded_x, use_custom=use_cag, dim=0 - ) + gathered_hidden_states = get_dp_group().all_gather(padded_x, use_custom=True, dim=0) return gathered_hidden_states, original_batch_size def reduce_scatter_with_unpadding( x: torch.Tensor, original_batch_size: int ) -> torch.Tensor: + dim = 0 dp_group = get_dp_group() + + # scattered_output = dp_group.reduce_scatter(x, dim=dim) scattered_output = dp_group.reduce_scatter_tensor(x) - # Drop the rows that pad_for_all_gather zero-padded (padding is on dim 0). - if scattered_output.shape[0] > original_batch_size: - scattered_output = scattered_output[:original_batch_size] + if scattered_output.shape[dim] > original_batch_size: + slices = [slice(None)] * scattered_output.ndim + slices[dim] = slice(0, original_batch_size) + scattered_output = scattered_output[slices] return scattered_output @@ -319,6 +286,25 @@ def dp_gather_hidden_and_router( return hidden_states, router_logits, original_hidden_size, None +def run_dp_collective_with_tbo_overlap(fn): + """Run a DP collective, yielding to TBO comm stream when overlap is active.""" + from atom.utils.tbo.ubatching import tbo_active + + if not tbo_active(): + return fn() + + from atom.utils.tbo.ubatching import ( + tbo_switch_to_compute_sync, + tbo_yield_and_switch_from_compute_to_comm, + ) + + tbo_yield_and_switch_from_compute_to_comm() + try: + return fn() + finally: + tbo_switch_to_compute_sync() + + @torch_compile_guard() def get_max_tokens_across_dispatchers(input: torch.Tensor) -> int: return input.item() @@ -392,63 +378,101 @@ def _maybe_make_prepare_finalize( # For 1x128 quant, the scale dim for each token is hidden_dim // 128 scale_dim = 1 if quant_config.is_per_act_token else moe.hidden_dim // 128 - # Check if quant_dtype is an FP8 type from aiter import QuantType - fp8_dtypes = ( - torch.float8_e4m3fn, - torch.float8_e4m3fnuz, - torch.float8_e5m2, - torch.float8_e5m2fnuz, - ) - is_fp8 = quant_config.quant_dtype in fp8_dtypes - # For FP8: enable FP8 dispatch in Mori (quantize before communication) - # Note: per_Tensor quant doesn't support num_local_tokens, so we use per_Token - use_fp8_dispatch = is_fp8 - quant_type = None - if use_fp8_dispatch: - if quant_config.is_block_quantized: - quant_type = QuantType.per_1x128 - elif quant_config.is_per_act_token: - quant_type = QuantType.per_Token - - # For FP8: use FP8 dtype for communication - # For FP4/no quant: use bfloat16 - # mori_dtype = ( - # quant_config.quant_dtype - # if is_fp8 and quant_type is not None - # else torch.bfloat16 - # ) - # mori_dtype = torch.bfloat16 + import os as _os + + dispatch_quant_dtype: torch.dtype | None = None + dispatch_dtype_env = _os.environ.get("MORI_DISPATCH_DTYPE", "auto").strip().lower() + if dispatch_dtype_env in ("auto", ""): + model_dispatch_quant = MoEActivationQuant.from_model_config( + moe.a_quant_dtype if isinstance(moe.a_quant_dtype, str) else None + ) + if model_dispatch_quant == MoEActivationQuant.FP8: + dispatch_quant_dtype = dtypes.fp8 + elif model_dispatch_quant == MoEActivationQuant.FP4: + dispatch_quant_dtype = dtypes.fp4x2 + elif dispatch_dtype_env in ("none", "off", "bf16"): + dispatch_quant_dtype = None + elif dispatch_dtype_env == "fp8": + dispatch_quant_dtype = dtypes.fp8 + elif dispatch_dtype_env == "fp4": + dispatch_quant_dtype = dtypes.fp4x2 + else: + logger.warning( + "Invalid MORI_DISPATCH_DTYPE=%s. Supported: auto|none|bf16|fp8|fp4. Falling back to auto.", + dispatch_dtype_env, + ) + model_dispatch_quant = MoEActivationQuant.from_model_config( + moe.a_quant_dtype if isinstance(moe.a_quant_dtype, str) else None + ) + if model_dispatch_quant == MoEActivationQuant.FP8: + dispatch_quant_dtype = dtypes.fp8 + elif model_dispatch_quant == MoEActivationQuant.FP4: + dispatch_quant_dtype = dtypes.fp4x2 + + # Backward compatibility for the old fp8 env knob. + if _os.environ.get("MORI_FP8_DISPATCH", "0") == "1": + dispatch_quant_dtype = dtypes.fp8 + + if dispatch_quant_dtype is not None and moe.hidden_dim % 32 != 0: + logger.warning( + "Disable quantized MORI dispatch because hidden_dim=%d is not divisible by 32.", + moe.hidden_dim, + ) + dispatch_quant_dtype = None + + # phase gating for quantized dispatch. Keep the legacy fp8-only env. + _dispatch_quant_decode_only = ( + _os.environ.get( + "MORI_DISPATCH_DECODE_ONLY", + _os.environ.get("MORI_FP8_DISPATCH_DECODE_ONLY", "0"), + ) + == "1" + ) + + if dispatch_quant_dtype is not None: + scale_dim = moe.hidden_dim // 32 + scale_type_size = 1 # e8m0 scale + else: + scale_type_size = torch.float32.itemsize + + # Combine quantization mode is provided by the standardized MORI + # dispatch/combine API. + combine_quant_type = _os.environ.get( + "MORI_COMBINE_QUANT_TYPE", "none" + ).strip().lower() + if combine_quant_type not in {"none", "fp8_direct_cast", "fp8_blockwise"}: + logger.warning( + "Invalid MORI_COMBINE_QUANT_TYPE=%s. Supported: none|fp8_direct_cast|fp8_blockwise. Falling back to none.", + combine_quant_type, + ) + combine_quant_type = "none" all_to_all_args = dict( rank=all2all_manager.rank, num_ep_ranks=all2all_manager.world_size, - # quant_dtype=mori_dtype, - # We now use bfloat16 for mori - # TODO: To support quant quant_dtype=moe.in_dtype, token_hidden_size=moe.hidden_dim, scale_dim=scale_dim, - scale_type_size=torch.float32.itemsize, + scale_type_size=scale_type_size, max_num_tokens_per_dp_rank=16384, - # input_dtype=moe.in_dtype, input_dtype=moe.in_dtype, num_local_experts=moe.num_experts // all2all_manager.world_size, num_experts_per_token=moe.experts_per_token, gpu_per_node=moe.moe_parallel_config.local_ep_size, + quant_type=combine_quant_type, ) - from atom.config import get_current_atom_config from atom.utils.tbo.ubatching import tbo_enabled + from atom.config import get_current_atom_config handle = all2all_manager.get_handle(all_to_all_args) is_async = tbo_enabled() atom_config = get_current_atom_config() low_latency = getattr(atom_config, "enable_low_latency", False) - # We not use quant for mori now - use_fp8_dispatch = False - quant_type = None + use_fp8_dispatch = dispatch_quant_dtype == dtypes.fp8 + quant_type = QuantType.per_1x32 if use_fp8_dispatch else None common_args = dict( rank=all2all_manager.rank, @@ -461,14 +485,16 @@ def _maybe_make_prepare_finalize( gpu_per_node=moe.moe_parallel_config.local_ep_size, data_type_itemsize=moe.in_dtype.itemsize, max_token_type_size=moe.in_dtype.itemsize, + scale_type_size=scale_type_size, + quant_type=combine_quant_type, ) tbo_mori_ops = None sync_handle = handle # IntraNode handle for prefill (sync path) if is_async: from atom.model_ops.fused_moe.mori_prepare_finalize import ( - _NUM_TBO_UBATCHES, init_mori_op, + _NUM_TBO_UBATCHES, ) tbo_mori_ops = [ @@ -485,10 +511,13 @@ def _maybe_make_prepare_finalize( max_tokens_per_rank=moe.max_num_tokens, num_dispatchers=all2all_manager.world_size, use_fp8_dispatch=use_fp8_dispatch, + dispatch_quant_dtype=dispatch_quant_dtype, quant_type=quant_type, is_async=is_async, tbo_mori_ops=tbo_mori_ops, low_latency=low_latency, + fp8_dispatch_decode_only=_dispatch_quant_decode_only, + dispatch_quant_decode_only=_dispatch_quant_decode_only, ) return prepare_finalize @@ -803,7 +832,10 @@ def __init__(self, quant_config: LayerQuantConfig, moe: FusedMoEConfig): self.use_triton = gfx.startswith("gfx94") or ( gfx.startswith("gfx95") and envs.ATOM_USE_TRITON_GEMM ) - self.act_quant = MoEActivationQuant.from_model_config(moe.a_quant_dtype) + if self.use_triton: + from atom.model_ops.utils import has_triton_kernels + + assert has_triton_kernels(), "triton_kernels is not installed" def create_weights( self, @@ -837,7 +869,7 @@ def create_weights( w13_weight = atom_parameter( torch.empty( num_experts, - 2 * intermediate_size_per_partition_after_pad, # TP included + 2 * intermediate_size_per_partition_after_pad, hidden_size // 2, dtype=weight_dtype, ) @@ -877,7 +909,7 @@ def create_weights( torch.empty( num_experts, hidden_size, - intermediate_size_per_partition_after_pad // 2, # TP included + intermediate_size_per_partition_after_pad // 2, dtype=weight_dtype, ) ) @@ -909,9 +941,6 @@ def create_weights( else: layer.register_parameter("w2_bias", None) - layer.w13_swizzle_layout = None - layer.w2_swizzle_layout = None - if self.static_input_scales: w13_input_scale = atom_parameter( torch.ones(num_experts, dtype=torch.float32) @@ -943,10 +972,14 @@ def process_weights_after_loading(self, layer): ) if self.use_triton: - from atom.config import get_current_atom_config + import dataclasses + from atom.model_ops.fused_moe_triton import _swizzle_mxfp4 - atom_config = get_current_atom_config() + try: + from triton_kernels.matmul import FlexCtx, PrecisionConfig + except ImportError: + from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig # Stash dense (pre-swizzle) shared-expert weights so the always-on # shared expert can be evaluated by a standalone dense MXFP4 GEMM @@ -990,26 +1023,37 @@ def process_weights_after_loading(self, layer): ) = _swizzle_mxfp4( layer.w13_weight.view(torch.uint8), layer.w13_weight_scale, + ) + w2_weight, w2_flex, w2_scale = _swizzle_mxfp4( layer.w2_weight.view(torch.uint8), layer.w2_weight_scale, - "mx4", - self.intermediate_size * 2, # N_1, - self.hidden_size, # K_1, - self.hidden_size, # N_2, - self.intermediate_size, # K_2, - atom_config.tensor_parallel_size, - act_quant=self.act_quant, ) + + _pc_field_names = {f.name for f in dataclasses.fields(PrecisionConfig)} + + def _build_precision_config(scale, flex): + kwargs = {"flex_ctx": FlexCtx(rhs_data=flex)} + if "weight_scale" in _pc_field_names: + kwargs["weight_scale"] = scale + else: + # New triton_kernels API renamed `weight_scale` → `b_mx_scale` + # and now requires the microblock size to be set explicitly. + from triton_kernels.numerics_details.mxfp import MXFP_BLOCK_SIZE + + kwargs["b_mx_scale"] = scale + kwargs["b_microblock_size"] = int(MXFP_BLOCK_SIZE) + return PrecisionConfig(**kwargs) + + self.w13_precision_config = _build_precision_config(w13_scale, w13_flex) + self.w2_precision_config = _build_precision_config(w2_scale, w2_flex) del layer.w13_weight del layer.w2_weight del layer.w13_weight_scale del layer.w2_weight_scale layer.w13_weight = w13_weight layer.w2_weight = w2_weight - layer.w13_weight_scale = w13_scale - layer.w2_weight_scale = w2_scale - layer.w13_swizzle_layout = w13_swizzle_layout - layer.w2_swizzle_layout = w2_swizzle_layout + layer.w13_weight_scale = None + layer.w2_weight_scale = None return # shuffle weight @@ -1050,19 +1094,6 @@ def process_weights_after_loading(self, layer): def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: - - a1_scale = getattr(layer, "w13_input_scale", None) - a2_scale = getattr(layer, "w2_input_scale", None) - - if self.act_quant == MoEActivationQuant.FP8: - return mxfp4_w4a8_moe_quant_config( - a1_scale=a1_scale, - a2_scale=a2_scale, - w1_bias=layer.w13_bias, - w2_bias=layer.w2_bias, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - ) return mxfp4_w4a16_moe_quant_config( w1_bias=layer.w13_bias, w2_bias=layer.w2_bias, @@ -1092,8 +1123,9 @@ def apply( ) -> torch.Tensor: if self.use_triton: from atom.model_ops.fused_moe_triton import ( - triton_kernel_fused_experts, triton_kernel_moe_forward, + triton_kernel_fused_experts, + fused_routing_from_topk_triton, ) # Check if the model needs custom routing that triton routing() @@ -1106,37 +1138,38 @@ def apply( ) if needs_custom_routing: - # custom routing -- set for deepseek routing n expts act, for grouped topk - n_expts_act = top_k - - # custom routing - from aiter.ops.triton.moe.moe_routing.routing import ( # grouped topk included - routing, - ) - - routing_data, gather_idx, scatter_idx = routing( - router_logits, - n_expts_act, - score_mode=scoring_func, - bias=( - e_score_correction_bias.to(torch.float32) - if e_score_correction_bias is not None - else None - ), - renorm=renormalize, - routed_scaling_factor=layer.routed_scaling_factor, + # Use ATOM's full-featured select_experts for routing, + # then triton matmul_ogs for the actual MoE computation. + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, use_grouped_topk=use_grouped_topk, - num_expert_group=num_expert_group, + top_k=top_k, + renormalize=renormalize, topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + num_fused_shared_experts=layer.num_fused_shared_experts, + routed_scaling_factor=layer.routed_scaling_factor, ) - # Routed-only gate count (no shared-expert widening). - n_expts_act = routing_data.n_expts_act + n_expts_act = topk_weights.shape[1] # Convert to triton routing data structures - num_tokens, n_expts_tot = router_logits.shape + if expert_map is not None: + # local_num_experts already includes fused shared experts + # (added at FusedMoE.__init__ line ~2056). + n_expts_tot = layer.local_num_experts + else: + n_expts_tot = router_logits.shape[-1] + if global_num_experts > 0: + n_expts_tot = global_num_experts + n_expts_tot = n_expts_tot + layer.num_fused_shared_experts - if global_num_experts > 0: - n_expts_tot = global_num_experts + routing_data, gather_idx, scatter_idx = fused_routing_from_topk_triton( + topk_weights, topk_ids, n_expts_tot, expert_map=expert_map + ) output = torch.empty_like(x) _moe_result = triton_kernel_fused_experts( @@ -1149,34 +1182,21 @@ def apply( scatter_idx, topk=n_expts_act, activation=activation, - w13_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - w13_swizzle_layout=layer.w13_swizzle_layout, - w2_swizzle_layout=layer.w2_swizzle_layout, - a13_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, + w13_precision_config=self.w13_precision_config, + w2_precision_config=self.w2_precision_config, w1_bias=layer.w13_bias, w2_bias=layer.w2_bias, swiglu_limit=getattr(layer, "swiglu_limit", 0.0), apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=n_expts_tot, expert_map=expert_map, - act_quant=self.act_quant, ) - - # Always-on shared expert(s) via a standalone dense GEMM, - # added to the routed output before the TP all-reduce. - if layer.num_fused_shared_experts > 0: - _moe_result = _moe_result + self._apply_shared_experts_dense( - layer, x, activation - ) return _moe_result assert ( fused_shared_experts_scoring_func is None ), "triton kernel does not support fused shared experts func" - # Takes directly from model dtype in config.json return triton_kernel_moe_forward( x, layer.w13_weight, @@ -1185,19 +1205,14 @@ def apply( topk=top_k, renormalize=renormalize, activation=activation, - w13_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - w13_swizzle_layout=layer.w13_swizzle_layout, - w2_swizzle_layout=layer.w2_swizzle_layout, - a13_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, + w13_precision_config=self.w13_precision_config, + w2_precision_config=self.w2_precision_config, w1_bias=layer.w13_bias, w2_bias=layer.w2_bias, swiglu_limit=getattr(layer, "swiglu_limit", 7.0), expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, - act_quant=self.act_quant, ) topk_weights, topk_ids = FusedMoE.select_experts( @@ -1345,7 +1360,6 @@ def _shared_expert_gemm(act, weight, weight_scale): shared_out = out_e if shared_out is None else shared_out + out_e return shared_out - # Refer to CompressedTensorsW8A8Fp8MoEMethod in vllm class CompressedTensorsFp8MoEMethod(FusedMoEMethodBase): def __init__(self, quant_config: LayerQuantConfig, moe: FusedMoEConfig): @@ -2232,7 +2246,6 @@ def __init__( ) self.layer_quant_config = layer_quant_config self.has_bias = has_bias - # Note: here we guard against accessing the TP and DP groups when # uninitialized (this happens when testing) # self.tp_size = 1 @@ -2339,18 +2352,6 @@ def __init__( self.e_score_correction_bias = e_score_correction_bias self.activation = activation - self.use_chunked = get_dp_group().world_size > 1 - - try: - a_quant_dtype = ( - config.quantization_config.get("global_quant_config", "") - .get("input_tensors", "") - .get("dtype", "") - ) - except AttributeError: - # global quant config does not exist, no activation loaded - a_quant_dtype = None - moe = FusedMoEConfig( num_experts=self.global_num_experts, experts_per_token=self.top_k, @@ -2358,7 +2359,6 @@ def __init__( num_local_experts=self.local_num_experts, moe_parallel_config=self.moe_parallel_config, in_dtype=atom_config.torch_dtype, - a_quant_dtype=a_quant_dtype, max_num_tokens=atom_config.max_num_batched_tokens, has_bias=self.has_bias, # is_act_and_mul=True, @@ -3237,6 +3237,20 @@ def select_experts( f"Unsupported scoring function for non-grouped topk: {scoring_func}" ) + # [balance exp] env-gated: force perfectly balanced round-robin routing + # over routed experts (ignore accuracy) to isolate load-imbalance from + # launch/sync overhead in the EP dispatch/combine path. + import os as _os + + if _os.environ.get("MORI_FORCE_BALANCED_ROUTING", "0") == "1": + _nt, _k = topk_ids.shape + _ne = router_logits.shape[-1] + topk_ids = ( + (torch.arange(_nt * _k, device=topk_ids.device) % _ne) + .reshape(_nt, _k) + .to(topk_ids.dtype) + ) + return topk_weights, topk_ids def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): @@ -3263,30 +3277,17 @@ def forward_impl_graph( ctx = get_forward_context() dp_group = get_dp_group() dp_eager_mode = not ctx.context.dp_uniform_decode - - from atom.utils.tbo.ubatching import tbo_active - - _tbo = tbo_active() - if _tbo: - from atom.utils.tbo.ubatching import ( - tbo_switch_to_compute_sync, - tbo_yield_and_switch_from_compute_to_comm, - ) - - tbo_yield_and_switch_from_compute_to_comm() - ( hidden_states, router_logits, original_hidden_size, sizes, - ) = dp_gather_hidden_and_router( - hidden_states, router_logits, dp_eager_mode, ctx, dp_group + ) = run_dp_collective_with_tbo_overlap( + lambda: dp_gather_hidden_and_router( + hidden_states, router_logits, dp_eager_mode, ctx, dp_group + ) ) - if _tbo: - tbo_switch_to_compute_sync() - # Matrix multiply. final_hidden_states = self.quant_method.apply( layer=self, @@ -3309,18 +3310,16 @@ def forward_impl_graph( # Use reduce_scatter when DP > 1 but not using mori all2all kernels if use_dp_gather_scatter: - if _tbo: - tbo_yield_and_switch_from_compute_to_comm() if dp_eager_mode: - final_hidden_states = reduce_scatterv( - final_hidden_states, sizes, dp_group + final_hidden_states = run_dp_collective_with_tbo_overlap( + lambda: reduce_scatterv(final_hidden_states, sizes, dp_group) ) else: - final_hidden_states = reduce_scatter_with_unpadding( - final_hidden_states, original_hidden_size + final_hidden_states = run_dp_collective_with_tbo_overlap( + lambda: reduce_scatter_with_unpadding( + final_hidden_states, original_hidden_size + ) ) - if _tbo: - tbo_switch_to_compute_sync() if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.)