From 7d46cd240736c8edcf881ea6304dc491f13743bc Mon Sep 17 00:00:00 2001 From: JiaoliangYu Date: Wed, 17 Jun 2026 13:22:52 +0800 Subject: [PATCH 01/11] feat: env-gated per-rank CPU/NUMA pinning for workers Pin each worker to a contiguous core range keyed on its rank at the very top of AsyncIOProc.__init__, before any large allocation, so Linux first-touch also places memory on the local NUMA node. Gated by ATOM_CPU_AFFINITY so baseline vs pinned A/B needs no code change. Co-Authored-By: Claude Opus 4.7 --- atom/model_engine/async_proc.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/atom/model_engine/async_proc.py b/atom/model_engine/async_proc.py index 49669d441f..ef719224de 100644 --- a/atom/model_engine/async_proc.py +++ b/atom/model_engine/async_proc.py @@ -15,6 +15,7 @@ import logging import multiprocessing +import os import pickle import queue import threading @@ -70,6 +71,19 @@ def __init__( *args, **kwargs, ): + # Per-rank CPU/NUMA pinning. Must run before any large allocation so + # Linux first-touch places memory on the local NUMA node (implicit + # membind without libnuma). Gated by env so baseline/pinned A/B is free. + if os.environ.get("ATOM_CPU_AFFINITY", "0") == "1": + n_cpu = os.cpu_count() + world = int(os.environ.get("ATOM_WORLD_SIZE", "8")) + per = n_cpu // world + cores = set(range(rank * per, rank * per + per)) + os.sched_setaffinity(0, cores) + logger.info( + f"AsyncIOProc({label}): pinned to cores " + f"{rank * per}-{rank * per + per - 1} (NUMA node {rank // (world // 2)})" + ) self.label = f"AsyncIOProc({label})" self.io_addrs = io_addrs self.io_queues = queue.Queue(), queue.Queue() From 24654378553b642ef5e982d6bdbc4389f57fc39d Mon Sep 17 00:00:00 2001 From: JiaoliangYu Date: Wed, 17 Jun 2026 15:34:46 +0800 Subject: [PATCH 02/11] cpu affinity --- .github/benchmark/models.json | 4 ++++ atom/model_engine/async_proc.py | 28 +++++++++++++++++++--------- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/.github/benchmark/models.json b/.github/benchmark/models.json index 28e0db1530..d864ee5ef5 100644 --- a/.github/benchmark/models.json +++ b/.github/benchmark/models.json @@ -65,7 +65,11 @@ "label": "DPA TBO", "suffix": "-dpa-tbo", "extra_args": "--enable-dp-attention --enable-tbo", +<<<<<<< HEAD "env_vars": "GPU_MAX_HW_QUEUES=5", +======= + "env_vars": "GPU_MAX_HW_QUEUES=5\nATOM_CPU_AFFINITY=1", +>>>>>>> 55624079 (fix: derive global GPU rank for affinity from config, not TP-local rank) "conc_min": 256, "conc_max": 1024 }, diff --git a/atom/model_engine/async_proc.py b/atom/model_engine/async_proc.py index ef719224de..b496a16941 100644 --- a/atom/model_engine/async_proc.py +++ b/atom/model_engine/async_proc.py @@ -74,16 +74,26 @@ def __init__( # Per-rank CPU/NUMA pinning. Must run before any large allocation so # Linux first-touch places memory on the local NUMA node (implicit # membind without libnuma). Gated by env so baseline/pinned A/B is free. + # The global GPU index is dp_rank * tp_size + tp_rank (see + # engine_core_mgr GPU assignment); the `rank` arg is only the TP-local + # rank, which is always 0 under DP-attention (tp_size == 1 per engine). if os.environ.get("ATOM_CPU_AFFINITY", "0") == "1": - n_cpu = os.cpu_count() - world = int(os.environ.get("ATOM_WORLD_SIZE", "8")) - per = n_cpu // world - cores = set(range(rank * per, rank * per + per)) - os.sched_setaffinity(0, cores) - logger.info( - f"AsyncIOProc({label}): pinned to cores " - f"{rank * per}-{rank * per + per - 1} (NUMA node {rank // (world // 2)})" - ) + try: + cfg = args[0] + tp_size = cfg.tensor_parallel_size + dp_size = cfg.parallel_config.data_parallel_size + dp_rank = cfg.parallel_config.data_parallel_rank + world = dp_size * tp_size + gpu = dp_rank * tp_size + rank + per = os.cpu_count() // world + lo = gpu * per + os.sched_setaffinity(0, set(range(lo, lo + per))) + logger.info( + f"AsyncIOProc({label}): gpu={gpu}/{world} " + f"pinned to cores {lo}-{lo + per - 1}" + ) + except Exception as e: + logger.warning(f"AsyncIOProc({label}): CPU affinity skipped: {e}") self.label = f"AsyncIOProc({label})" self.io_addrs = io_addrs self.io_queues = queue.Queue(), queue.Queue() From 7836173260fc7508616217556969fc650c564795 Mon Sep 17 00:00:00 2001 From: JiaoliangYu Date: Wed, 17 Jun 2026 15:54:59 +0800 Subject: [PATCH 03/11] resolve confict --- .github/benchmark/models.json | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.github/benchmark/models.json b/.github/benchmark/models.json index d864ee5ef5..92301c932e 100644 --- a/.github/benchmark/models.json +++ b/.github/benchmark/models.json @@ -65,11 +65,7 @@ "label": "DPA TBO", "suffix": "-dpa-tbo", "extra_args": "--enable-dp-attention --enable-tbo", -<<<<<<< HEAD - "env_vars": "GPU_MAX_HW_QUEUES=5", -======= "env_vars": "GPU_MAX_HW_QUEUES=5\nATOM_CPU_AFFINITY=1", ->>>>>>> 55624079 (fix: derive global GPU rank for affinity from config, not TP-local rank) "conc_min": 256, "conc_max": 1024 }, From 2417b95942820f24bee6e6a0c36d13f355737772 Mon Sep 17 00:00:00 2001 From: JiaoliangYu Date: Wed, 17 Jun 2026 16:10:17 +0800 Subject: [PATCH 04/11] ci: lock GPU clock to 2400MHz (determinism) during benchmarks Add amd-smi performance-determinism lock before the benchmark/server runs and an always() unlock to AUTO before container teardown, in both the main benchmark job and the regression-rerun job. The lock is driver-level and persists across jobs on the bare-metal runner, so the unlock must run even on failure. Co-Authored-By: Claude Opus 4.7 --- .github/workflows/atom-benchmark.yaml | 29 +++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/.github/workflows/atom-benchmark.yaml b/.github/workflows/atom-benchmark.yaml index a542879f06..4b1a104e60 100644 --- a/.github/workflows/atom-benchmark.yaml +++ b/.github/workflows/atom-benchmark.yaml @@ -198,6 +198,15 @@ jobs: echo "docker_image=${DOCKER_IMAGE}" >> $GITHUB_OUTPUT echo "Docker: ${DOCKER_IMAGE}" + - name: Lock GPU clock (determinism) + run: | + # Container is --privileged so /sys is writable and amd-smi can set + # performance determinism (the only working clock-lock on MI355X; + # -l HIGH / setperflevel / sysfs writes are all rejected). + GPUS=$(docker exec atom-benchmark bash -lc "amd-smi list --csv 2>/dev/null | tail -n +2 | cut -d, -f1 | tr '\n' ' '") + echo "Locking GPUs [$GPUS] to 2400 MHz" + docker exec atom-benchmark bash -lc "amd-smi set -g $GPUS -d 2400" + - name: Run benchmark timeout-minutes: 80 env: @@ -296,6 +305,14 @@ jobs: name: benchmark-${{ env.RESULT_FILENAME }} path: ${{ env.RESULT_FILENAME }}.json + - name: Unlock GPU clock + if: always() + run: | + # Self-hosted bare-metal runner: the clock lock is driver-level and + # persists across jobs, so it MUST be restored even on failure. + docker exec atom-benchmark bash -lc \ + 'amd-smi set -g $(amd-smi list --csv 2>/dev/null | tail -n +2 | cut -d, -f1 | tr "\n" " ") -l AUTO' || true + - name: Clean Up if: always() run: | @@ -523,6 +540,12 @@ jobs: hf-token: ${{ secrets.AMD_HF_TOKEN }} download-required: "false" + - name: Lock GPU clock (determinism) + run: | + GPUS=$(docker exec atom-regression bash -lc "amd-smi list --csv 2>/dev/null | tail -n +2 | cut -d, -f1 | tr '\n' ' '") + echo "Locking GPUs [$GPUS] to 2400 MHz" + docker exec atom-regression bash -lc "amd-smi set -g $GPUS -d 2400" + - name: Launch server env: SERVER_ARGS: ${{ matrix.cell.server_args }} @@ -648,6 +671,12 @@ jobs: path: regression-*.json if-no-files-found: ignore + - name: Unlock GPU clock + if: always() + run: | + docker exec atom-regression bash -lc \ + 'amd-smi set -g $(amd-smi list --csv 2>/dev/null | tail -n +2 | cut -d, -f1 | tr "\n" " ") -l AUTO' || true + - name: Clean Up if: always() run: | From 1317bf0d4fdc9b3e79959615359aa87e04a69cf9 Mon Sep 17 00:00:00 2001 From: JiaoliangYu Date: Wed, 17 Jun 2026 17:36:05 +0800 Subject: [PATCH 05/11] sync base from container 0926941f (moe.py + mori_prepare_finalize.py) --- atom/model_ops/moe.py | 356 ++++++++++++------------------------------ 1 file changed, 99 insertions(+), 257 deletions(-) diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index ad45981e36..ff4bfc12ee 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -import logging import os from abc import abstractmethod from dataclasses import dataclass @@ -14,13 +13,14 @@ from aiter.fused_moe import fused_moe from aiter.jit.utils.chip_info import get_gfx from aiter.jit.utils.torch_guard import torch_compile_guard -from aiter.ops.flydsl.moe_common import GateMode -from aiter.ops.shuffle import shuffle_scale, shuffle_weight +from aiter.ops.shuffle import shuffle_weight, shuffle_scale from atom.config import ( Config, QuantizationConfig, get_current_atom_config, ) +from aiter.ops.flydsl.moe_common import GateMode +from atom.quant_spec import LayerQuantConfig, should_skip_online_quant from atom.model_loader.weight_utils import set_weight_attrs from atom.model_ops.base_config import QuantizeMethodBase from atom.model_ops.fused_moe.config import ( @@ -28,7 +28,6 @@ FusedMoEConfig, FusedMoEQuantConfig, fp8_w8a8_moe_quant_config, - mxfp4_w4a8_moe_quant_config, mxfp4_w4a16_moe_quant_config, ) from atom.model_ops.fused_moe.modular_kernel import ( @@ -50,34 +49,14 @@ per_tensor_dequantize, shuffle_weights, ) -from atom.plugin.vllm.moe import FusedMoEDecoratorForPluginMode -from atom.quant_spec import LayerQuantConfig, should_skip_online_quant -from atom.quantization.quark.utils import weight_dequant_fp8 from atom.utils import envs from atom.utils.custom_register import direct_register_custom_op -from atom.utils.decorators import mark_trace from atom.utils.forward_context import get_forward_context +from atom.utils.decorators import mark_trace from torch import nn from transformers import PretrainedConfig - -logger = logging.getLogger("atom") - - -class MoEActivationQuant(Enum): - BF16 = "bf16" - FP8 = "fp8" - FP4 = "fp4" - - @staticmethod - def from_model_config(a_quant_dtype: str | None) -> "MoEActivationQuant": - if a_quant_dtype is None or a_quant_dtype == "": - return MoEActivationQuant.BF16 - prefix = a_quant_dtype.split("_")[0] - if prefix == "fp8": - return MoEActivationQuant.FP8 - if prefix in ("fp4", "uint8"): - return MoEActivationQuant.FP4 - return MoEActivationQuant.BF16 +from atom.plugin.vllm.moe import FusedMoEDecoratorForPluginMode +from atom.quantization.quark.utils import weight_dequant_fp8 class FusedMoeWeightScaleSupported(Enum): @@ -201,63 +180,51 @@ def naive_multicast( return buffer -def pad_for_all_gather(x: torch.Tensor) -> Tuple[torch.Tensor, int]: - """Zero-pad ``x`` along dim 0 up to the uniform all-gather batch size. - - Every DP rank must contribute the same number of rows to the uniform - all-gather, so a short batch is padded up to ``graph_bs`` (scaled by the - per-sequence query length when decoding with MTP > 1). - - The padding MUST be zeros, never uninitialized memory: padded rows are - all-gathered across DP ranks and fed straight into the aiter fused-MoE - expert GEMM, where garbage values leak into real tokens' outputs. - Bisection traced a ~0.7pp GSM8K drop at large batch to a bare - ``torch.empty`` pad here; explicitly zeroing the pad rows fixes it. - - Returns the (possibly padded) tensor and the original row count so the - caller can unpad after reduce-scatter. - """ +def pad_for_all_gather(x: torch.Tensor): ctx = get_forward_context() max_batch_size = ctx.context.graph_bs if not ctx.context.is_prefill and ctx.attn_metadata is not None: + # For MTP > 1 max_batch_size *= ctx.attn_metadata.max_seqlen_q + dim = 0 + original_batch_size = x.shape[dim] + padded_x = x + if original_batch_size < max_batch_size: + padding_size = max_batch_size - original_batch_size - original_batch_size = x.shape[0] - padding_size = max_batch_size - original_batch_size - if padding_size <= 0: - return x, original_batch_size + padding_shape = list(x.shape) + padding_shape[dim] = padding_size + + padding = torch.empty(padding_shape, dtype=x.dtype, device=x.device) + # padding.zero_() + padded_x = torch.cat([x, padding], dim=dim) - padding_shape = list(x.shape) - padding_shape[0] = max_batch_size - padded_x = torch.empty(padding_shape, device=x.device, dtype=x.dtype) - padded_x[:original_batch_size, :].copy_(x) - # padded_x[original_batch_size:, :].zero_() return padded_x, original_batch_size -def all_gather_with_padding( - x: torch.Tensor, use_cag: bool = True -) -> Tuple[torch.Tensor, int]: +def all_gather_with_padding(x: torch.Tensor): padded_x, original_batch_size = pad_for_all_gather(x) # use_custom=True routes through CA IPC (outplace_all_gather). Default # use_custom=False falls back to torch.distributed.all_gather_into_tensor # (NCCL), whose WorkNCCL end-event recorded inside CUDAGraph capture is # later queried by the watchdog thread -> hipErrorCapturedEvent crash. - gathered_hidden_states = get_dp_group().all_gather( - padded_x, use_custom=use_cag, dim=0 - ) + gathered_hidden_states = get_dp_group().all_gather(padded_x, use_custom=True, dim=0) return gathered_hidden_states, original_batch_size def reduce_scatter_with_unpadding( x: torch.Tensor, original_batch_size: int ) -> torch.Tensor: + dim = 0 dp_group = get_dp_group() + + # scattered_output = dp_group.reduce_scatter(x, dim=dim) scattered_output = dp_group.reduce_scatter_tensor(x) - # Drop the rows that pad_for_all_gather zero-padded (padding is on dim 0). - if scattered_output.shape[0] > original_batch_size: - scattered_output = scattered_output[:original_batch_size] + if scattered_output.shape[dim] > original_batch_size: + slices = [slice(None)] * scattered_output.ndim + slices[dim] = slice(0, original_batch_size) + scattered_output = scattered_output[slices] return scattered_output @@ -438,8 +405,8 @@ def _maybe_make_prepare_finalize( num_experts_per_token=moe.experts_per_token, gpu_per_node=moe.moe_parallel_config.local_ep_size, ) - from atom.config import get_current_atom_config from atom.utils.tbo.ubatching import tbo_enabled + from atom.config import get_current_atom_config handle = all2all_manager.get_handle(all_to_all_args) is_async = tbo_enabled() @@ -467,8 +434,8 @@ def _maybe_make_prepare_finalize( sync_handle = handle # IntraNode handle for prefill (sync path) if is_async: from atom.model_ops.fused_moe.mori_prepare_finalize import ( - _NUM_TBO_UBATCHES, init_mori_op, + _NUM_TBO_UBATCHES, ) tbo_mori_ops = [ @@ -791,7 +758,10 @@ def __init__(self, quant_config: LayerQuantConfig, moe: FusedMoEConfig): or gfx.startswith("gfx12") or (gfx.startswith("gfx95") and envs.ATOM_USE_TRITON_GEMM) ) - self.act_quant = MoEActivationQuant.from_model_config(moe.a_quant_dtype) + if self.use_triton: + from atom.model_ops.utils import has_triton_kernels + + assert has_triton_kernels(), "triton_kernels is not installed" def create_weights( self, @@ -825,7 +795,7 @@ def create_weights( w13_weight = atom_parameter( torch.empty( num_experts, - 2 * intermediate_size_per_partition_after_pad, # TP included + 2 * intermediate_size_per_partition_after_pad, hidden_size // 2, dtype=weight_dtype, ) @@ -865,7 +835,7 @@ def create_weights( torch.empty( num_experts, hidden_size, - intermediate_size_per_partition_after_pad // 2, # TP included + intermediate_size_per_partition_after_pad // 2, dtype=weight_dtype, ) ) @@ -897,9 +867,6 @@ def create_weights( else: layer.register_parameter("w2_bias", None) - layer.w13_swizzle_layout = None - layer.w2_swizzle_layout = None - if self.static_input_scales: w13_input_scale = atom_parameter( torch.ones(num_experts, dtype=torch.float32) @@ -926,65 +893,49 @@ def process_weights_after_loading(self, layer): return if self.use_triton: - from atom.config import get_current_atom_config - from atom.model_ops.fused_moe_triton import _swizzle_mxfp4 + import dataclasses - atom_config = get_current_atom_config() + from atom.model_ops.fused_moe_triton import _swizzle_mxfp4 - # Stash dense (pre-swizzle) shared-expert weights so the always-on - # shared expert can be evaluated by a standalone dense MXFP4 GEMM - # (gemm_a16wfp4, see Mxfp4MoEMethod._apply_shared_experts_dense) - # instead of being fused into grouped_topk routing. The shared - # experts occupy the last ``num_fused_shared_experts`` slots of the - # routed weight tensors; their raw per-expert layout - # (N, K // 2) + scale (N, K // 32) is exactly what gemm_a16wfp4 - # consumes, whereas _swizzle_mxfp4 below reorders the scales into - # the MoE-kernel-only CDNA4 layout. - n_shared = layer.num_fused_shared_experts - if n_shared > 0: - layer.shared_w13_weight = ( - layer.w13_weight.data[-n_shared:].view(torch.uint8).contiguous() - ) - layer.shared_w13_weight_scale = layer.w13_weight_scale.data[ - -n_shared: - ].contiguous() - layer.shared_w2_weight = ( - layer.w2_weight.data[-n_shared:].view(torch.uint8).contiguous() - ) - layer.shared_w2_weight_scale = layer.w2_weight_scale.data[ - -n_shared: - ].contiguous() + try: + from triton_kernels.matmul import FlexCtx, PrecisionConfig + except ImportError: + from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig - ( - w13_weight, - w13_scale, - w13_swizzle_layout, - w2_weight, - w2_scale, - w2_swizzle_layout, - ) = _swizzle_mxfp4( + w13_weight, w13_flex, w13_scale = _swizzle_mxfp4( layer.w13_weight.view(torch.uint8), layer.w13_weight_scale, + ) + w2_weight, w2_flex, w2_scale = _swizzle_mxfp4( layer.w2_weight.view(torch.uint8), layer.w2_weight_scale, - "mx4", - self.intermediate_size * 2, # N_1, - self.hidden_size, # K_1, - self.hidden_size, # N_2, - self.intermediate_size, # K_2, - atom_config.tensor_parallel_size, - act_quant=self.act_quant, ) + + _pc_field_names = {f.name for f in dataclasses.fields(PrecisionConfig)} + + def _build_precision_config(scale, flex): + kwargs = {"flex_ctx": FlexCtx(rhs_data=flex)} + if "weight_scale" in _pc_field_names: + kwargs["weight_scale"] = scale + else: + # New triton_kernels API renamed `weight_scale` → `b_mx_scale` + # and now requires the microblock size to be set explicitly. + from triton_kernels.numerics_details.mxfp import MXFP_BLOCK_SIZE + + kwargs["b_mx_scale"] = scale + kwargs["b_microblock_size"] = int(MXFP_BLOCK_SIZE) + return PrecisionConfig(**kwargs) + + self.w13_precision_config = _build_precision_config(w13_scale, w13_flex) + self.w2_precision_config = _build_precision_config(w2_scale, w2_flex) del layer.w13_weight del layer.w2_weight del layer.w13_weight_scale del layer.w2_weight_scale layer.w13_weight = w13_weight layer.w2_weight = w2_weight - layer.w13_weight_scale = w13_scale - layer.w2_weight_scale = w2_scale - layer.w13_swizzle_layout = w13_swizzle_layout - layer.w2_swizzle_layout = w2_swizzle_layout + layer.w13_weight_scale = None + layer.w2_weight_scale = None return # shuffle weight @@ -1019,19 +970,6 @@ def process_weights_after_loading(self, layer): def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: - - a1_scale = getattr(layer, "w13_input_scale", None) - a2_scale = getattr(layer, "w2_input_scale", None) - - if self.act_quant == MoEActivationQuant.FP8: - return mxfp4_w4a8_moe_quant_config( - a1_scale=a1_scale, - a2_scale=a2_scale, - w1_bias=layer.w13_bias, - w2_bias=layer.w2_bias, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - ) return mxfp4_w4a16_moe_quant_config( w1_bias=layer.w13_bias, w2_bias=layer.w2_bias, @@ -1061,8 +999,9 @@ def apply( ) -> torch.Tensor: if self.use_triton: from atom.model_ops.fused_moe_triton import ( - triton_kernel_fused_experts, triton_kernel_moe_forward, + triton_kernel_fused_experts, + fused_routing_from_topk_triton, ) # Check if the model needs custom routing that triton routing() @@ -1075,37 +1014,38 @@ def apply( ) if needs_custom_routing: - # custom routing -- set for deepseek routing n expts act, for grouped topk - n_expts_act = top_k - - # custom routing - from aiter.ops.triton.moe.moe_routing.routing import ( # grouped topk included - routing, - ) - - routing_data, gather_idx, scatter_idx = routing( - router_logits, - n_expts_act, - score_mode=scoring_func, - bias=( - e_score_correction_bias.to(torch.float32) - if e_score_correction_bias is not None - else None - ), - renorm=renormalize, - routed_scaling_factor=layer.routed_scaling_factor, + # Use ATOM's full-featured select_experts for routing, + # then triton matmul_ogs for the actual MoE computation. + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, use_grouped_topk=use_grouped_topk, - num_expert_group=num_expert_group, + top_k=top_k, + renormalize=renormalize, topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + num_fused_shared_experts=layer.num_fused_shared_experts, + routed_scaling_factor=layer.routed_scaling_factor, ) - # Routed-only gate count (no shared-expert widening). - n_expts_act = routing_data.n_expts_act + n_expts_act = topk_weights.shape[1] # Convert to triton routing data structures - num_tokens, n_expts_tot = router_logits.shape + if expert_map is not None: + # local_num_experts already includes fused shared experts + # (added at FusedMoE.__init__ line ~2056). + n_expts_tot = layer.local_num_experts + else: + n_expts_tot = router_logits.shape[-1] + if global_num_experts > 0: + n_expts_tot = global_num_experts + n_expts_tot = n_expts_tot + layer.num_fused_shared_experts - if global_num_experts > 0: - n_expts_tot = global_num_experts + routing_data, gather_idx, scatter_idx = fused_routing_from_topk_triton( + topk_weights, topk_ids, n_expts_tot, expert_map=expert_map + ) output = torch.empty_like(x) _moe_result = triton_kernel_fused_experts( @@ -1118,34 +1058,21 @@ def apply( scatter_idx, topk=n_expts_act, activation=activation, - w13_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - w13_swizzle_layout=layer.w13_swizzle_layout, - w2_swizzle_layout=layer.w2_swizzle_layout, - a13_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, + w13_precision_config=self.w13_precision_config, + w2_precision_config=self.w2_precision_config, w1_bias=layer.w13_bias, w2_bias=layer.w2_bias, swiglu_limit=getattr(layer, "swiglu_limit", 0.0), apply_router_weight_on_input=layer.apply_router_weight_on_input, global_num_experts=n_expts_tot, expert_map=expert_map, - act_quant=self.act_quant, ) - - # Always-on shared expert(s) via a standalone dense GEMM, - # added to the routed output before the TP all-reduce. - if layer.num_fused_shared_experts > 0: - _moe_result = _moe_result + self._apply_shared_experts_dense( - layer, x, activation - ) return _moe_result assert ( fused_shared_experts_scoring_func is None ), "triton kernel does not support fused shared experts func" - # Takes directly from model dtype in config.json return triton_kernel_moe_forward( x, layer.w13_weight, @@ -1154,18 +1081,13 @@ def apply( topk=top_k, renormalize=renormalize, activation=activation, - w13_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - w13_swizzle_layout=layer.w13_swizzle_layout, - w2_swizzle_layout=layer.w2_swizzle_layout, - a13_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, + w13_precision_config=self.w13_precision_config, + w2_precision_config=self.w2_precision_config, w1_bias=layer.w13_bias, w2_bias=layer.w2_bias, expert_map=expert_map, apply_router_weight_on_input=layer.apply_router_weight_on_input, global_num_experts=global_num_experts, - act_quant=self.act_quant, ) topk_weights, topk_ids = FusedMoE.select_experts( @@ -1239,73 +1161,6 @@ def apply( moe_extra_args=moe_extra_args, ) - def _apply_shared_experts_dense(self, layer, x, activation): - """Standalone dense MXFP4 GEMM for the always-on shared expert(s). - - Functionally replaces fusing the shared expert into grouped_topk - routing. The fused path appended each shared expert as an always-on - routed slot with a fixed weight of ``SHARED_SCORE`` (== 1.0 here), - placed *after* the routed renorm / ``routed_scaling_factor`` and run - through the MoE GEMM as the last expert(s) of ``w13/w2``. - - Here we instead apply the shared expert to every token via two dense - MXFP4 GEMM calls (gate_up -> SiLU-and-mul -> down) on the pre-swizzle - weight slices stashed in ``process_weights_after_loading``. The dense - GEMM is ``gemm_afp4wfp4`` (a4w4) when activations are MXFP4, otherwise - ``gemm_a16wfp4`` (a16w4), matching the routed-expert activation dtype, - and return the result so the caller adds it (weight 1.0) to the routed - output before the TP all-reduce. The shared-expert intermediate is - TP-partitioned exactly like the routed experts, so both partial outputs - reduce together. - """ - from aiter.ops.triton.fusions.fused_clamp_act_mul import fused_clamp_act_mul - from aiter.ops.triton.gemm.basic.gemm_a16wfp4 import gemm_a16wfp4 - - # The dense shared-expert GEMM only implements the SiLU activation - # path; SwiGLU models have no fused shared experts, so this assert - # documents the supported scope. - assert ( - activation != ActivationType.Swiglu - ), "dense shared-expert GEMM only supports the SiLU activation path" - - M = x.shape[0] - swiglu_limit = getattr(layer, "swiglu_limit", 0.0) - - use_a4w4 = self.act_quant == MoEActivationQuant.FP4 - if use_a4w4: - from aiter.ops.triton.gemm.basic.gemm_afp4wfp4 import gemm_afp4wfp4 - from aiter.ops.triton.moe.moe_op_gemm_a4w4 import mxfp4_quant - - def _shared_expert_gemm(act, weight, weight_scale): - if use_a4w4: - act_fp4, act_mx_scale = mxfp4_quant(act) - return gemm_afp4wfp4(act_fp4, weight, act_mx_scale, weight_scale) - return gemm_a16wfp4(act, weight, weight_scale) - - shared_out = None - for e in range(layer.num_fused_shared_experts): - gate_up = _shared_expert_gemm( - x, - layer.shared_w13_weight[e], - layer.shared_w13_weight_scale[e], - ) - half_n = gate_up.shape[-1] // 2 - intermediate = torch.empty((M, half_n), device=x.device, dtype=x.dtype) - fused_clamp_act_mul( - gate_up, - out=intermediate, - swiglu_limit=swiglu_limit, - activation="silu", - dtype_quant=None, - ) - out_e = _shared_expert_gemm( - intermediate, - layer.shared_w2_weight[e], - layer.shared_w2_weight_scale[e], - ) - shared_out = out_e if shared_out is None else shared_out + out_e - return shared_out - # Refer to CompressedTensorsW8A8Fp8MoEMethod in vllm class CompressedTensorsFp8MoEMethod(FusedMoEMethodBase): @@ -2187,7 +2042,6 @@ def __init__( ) self.layer_quant_config = layer_quant_config self.has_bias = has_bias - # Note: here we guard against accessing the TP and DP groups when # uninitialized (this happens when testing) # self.tp_size = 1 @@ -2294,18 +2148,6 @@ def __init__( self.e_score_correction_bias = e_score_correction_bias self.activation = activation - self.use_chunked = get_dp_group().world_size > 1 - - try: - a_quant_dtype = ( - config.quantization_config.get("global_quant_config", "") - .get("input_tensors", "") - .get("dtype", "") - ) - except AttributeError: - # global quant config does not exist, no activation loaded - a_quant_dtype = None - moe = FusedMoEConfig( num_experts=self.global_num_experts, experts_per_token=self.top_k, @@ -2313,7 +2155,6 @@ def __init__( num_local_experts=self.local_num_experts, moe_parallel_config=self.moe_parallel_config, in_dtype=atom_config.torch_dtype, - a_quant_dtype=a_quant_dtype, max_num_tokens=atom_config.max_num_batched_tokens, has_bias=self.has_bias, # is_act_and_mul=True, @@ -3246,6 +3087,7 @@ def forward_impl_graph( from atom.utils.tbo.ubatching import ( tbo_switch_to_compute_sync, tbo_yield_and_switch_from_compute_to_comm, + tbo_yield_and_switch_from_comm_to_compute, ) tbo_yield_and_switch_from_compute_to_comm() @@ -3295,7 +3137,7 @@ def forward_impl_graph( final_hidden_states, original_hidden_size ) if _tbo: - tbo_switch_to_compute_sync() + tbo_yield_and_switch_from_comm_to_compute() if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.) From 7a16b4089463cc9957f754206295c06b54ad2a81 Mon Sep 17 00:00:00 2001 From: JiaoliangYu Date: Wed, 17 Jun 2026 17:37:54 +0800 Subject: [PATCH 06/11] feat: env-gated MXFP8 fp8 dispatch (MORI_FP8_DISPATCH) --- .../fused_moe/mori_prepare_finalize.py | 9 ++++++-- atom/model_ops/moe.py | 23 +++++++++++++++---- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/atom/model_ops/fused_moe/mori_prepare_finalize.py b/atom/model_ops/fused_moe/mori_prepare_finalize.py index a8eaf355d0..a9e0238523 100644 --- a/atom/model_ops/fused_moe/mori_prepare_finalize.py +++ b/atom/model_ops/fused_moe/mori_prepare_finalize.py @@ -190,8 +190,13 @@ def prepare( if self.use_fp8_dispatch: from aiter import get_hip_quant - quant_func = get_hip_quant(quant_type) - a1, scale = quant_func(a1, quant_dtype=dtypes.fp8) + # [fp8-dispatch exp] MXFP8 per_1x32 + e8m0 byte scale, matching + # DSv4 expert GEMM (q_type=per_1x32). dispatch_scale is wired as + # the GEMM a1_scale in modular_kernel (a1_scale=dispatch_scale). + quant_func = get_hip_quant(QuantType.per_1x32) + a1, scale = quant_func( + a1, quant_dtype=dtypes.fp8, scale_type=dtypes.fp8_e8m0 + ) block_num, warp_per_block = self._get_dispatch_config() diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index ff4bfc12ee..da677dfa92 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -359,6 +359,15 @@ def _maybe_make_prepare_finalize( # For 1x128 quant, the scale dim for each token is hidden_dim // 128 scale_dim = 1 if quant_config.is_per_act_token else moe.hidden_dim // 128 + # [fp8-dispatch exp] env-gated MXFP8 (per_1x32) dispatch. DSv4 expert + # GEMM uses q_type=per_1x32 (block-32, e8m0 scale), so dispatch must + # quantize a1 to MXFP8 with hidden//32 scale groups. + import os as _os + + _fp8_dispatch_exp = _os.environ.get("MORI_FP8_DISPATCH", "0") == "1" + if _fp8_dispatch_exp: + scale_dim = moe.hidden_dim // 32 + # Check if quant_dtype is an FP8 type from aiter import QuantType @@ -397,7 +406,8 @@ def _maybe_make_prepare_finalize( quant_dtype=moe.in_dtype, token_hidden_size=moe.hidden_dim, scale_dim=scale_dim, - scale_type_size=torch.float32.itemsize, + # [fp8-dispatch exp] e8m0 byte scale (1B) for MXFP8 dispatch + scale_type_size=(1 if _fp8_dispatch_exp else torch.float32.itemsize), max_num_tokens_per_dp_rank=16384, # input_dtype=moe.in_dtype, input_dtype=moe.in_dtype, @@ -413,9 +423,14 @@ def _maybe_make_prepare_finalize( atom_config = get_current_atom_config() low_latency = getattr(atom_config, "enable_low_latency", False) - # We not use quant for mori now - use_fp8_dispatch = False - quant_type = None + # [fp8-dispatch exp] env-gated MXFP8 dispatch (DSv4 q_type=per_1x32). + # Default off -> baseline bf16 dispatch unchanged. + if _fp8_dispatch_exp and is_fp8: + use_fp8_dispatch = True + quant_type = QuantType.per_1x32 + else: + use_fp8_dispatch = False + quant_type = None common_args = dict( rank=all2all_manager.rank, From 95c21de18eed41854596af8b94fc927a0de33494 Mon Sep 17 00:00:00 2001 From: JiaoliangYu Date: Wed, 17 Jun 2026 17:37:54 +0800 Subject: [PATCH 07/11] feat: env-gated MXFP8 fp8 dispatch (MORI_FP8_DISPATCH), sync EP path only --- .../fused_moe/mori_prepare_finalize.py | 9 ++++-- atom/model_ops/moe.py | 28 ++++++++++++++++--- 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/atom/model_ops/fused_moe/mori_prepare_finalize.py b/atom/model_ops/fused_moe/mori_prepare_finalize.py index a8eaf355d0..a9e0238523 100644 --- a/atom/model_ops/fused_moe/mori_prepare_finalize.py +++ b/atom/model_ops/fused_moe/mori_prepare_finalize.py @@ -190,8 +190,13 @@ def prepare( if self.use_fp8_dispatch: from aiter import get_hip_quant - quant_func = get_hip_quant(quant_type) - a1, scale = quant_func(a1, quant_dtype=dtypes.fp8) + # [fp8-dispatch exp] MXFP8 per_1x32 + e8m0 byte scale, matching + # DSv4 expert GEMM (q_type=per_1x32). dispatch_scale is wired as + # the GEMM a1_scale in modular_kernel (a1_scale=dispatch_scale). + quant_func = get_hip_quant(QuantType.per_1x32) + a1, scale = quant_func( + a1, quant_dtype=dtypes.fp8, scale_type=dtypes.fp8_e8m0 + ) block_num, warp_per_block = self._get_dispatch_config() diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index bd82373aef..7f228ed624 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -402,6 +402,19 @@ def _maybe_make_prepare_finalize( torch.float8_e5m2fnuz, ) is_fp8 = quant_config.quant_dtype in fp8_dtypes + + # [fp8-dispatch exp] env-gated MXFP8 (per_1x32) dispatch, FP8 models + # only. DSv4 expert GEMM uses q_type=per_1x32 (block-32, e8m0 scale), + # so dispatch quantizes a1 to MXFP8 with hidden//32 scale groups. + # Gated on is_fp8 so non-FP8 MORI handle config stays untouched. + import os as _os + + _fp8_dispatch_exp = ( + _os.environ.get("MORI_FP8_DISPATCH", "0") == "1" and is_fp8 + ) + if _fp8_dispatch_exp: + scale_dim = moe.hidden_dim // 32 + # For FP8: enable FP8 dispatch in Mori (quantize before communication) # Note: per_Tensor quant doesn't support num_local_tokens, so we use per_Token use_fp8_dispatch = is_fp8 @@ -430,7 +443,8 @@ def _maybe_make_prepare_finalize( quant_dtype=moe.in_dtype, token_hidden_size=moe.hidden_dim, scale_dim=scale_dim, - scale_type_size=torch.float32.itemsize, + # [fp8-dispatch exp] e8m0 byte scale (1B) for MXFP8 dispatch + scale_type_size=(1 if _fp8_dispatch_exp else torch.float32.itemsize), max_num_tokens_per_dp_rank=16384, # input_dtype=moe.in_dtype, input_dtype=moe.in_dtype, @@ -446,9 +460,15 @@ def _maybe_make_prepare_finalize( atom_config = get_current_atom_config() low_latency = getattr(atom_config, "enable_low_latency", False) - # We not use quant for mori now - use_fp8_dispatch = False - quant_type = None + # [fp8-dispatch exp] env-gated MXFP8 dispatch (DSv4 q_type=per_1x32). + # Default off -> baseline bf16 dispatch unchanged. _fp8_dispatch_exp + # already includes the is_fp8 check. + if _fp8_dispatch_exp: + use_fp8_dispatch = True + quant_type = QuantType.per_1x32 + else: + use_fp8_dispatch = False + quant_type = None common_args = dict( rank=all2all_manager.rank, From 2657110bd50394b0777a8c1b50c94c31add28bd5 Mon Sep 17 00:00:00 2001 From: JiaoliangYu Date: Wed, 17 Jun 2026 19:12:54 +0800 Subject: [PATCH 08/11] exp: env-gated balanced round-robin routing (MORI_FORCE_BALANCED_ROUTING) --- atom/model_ops/moe.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index 7f228ed624..19095fe279 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -3257,6 +3257,20 @@ def select_experts( f"Unsupported scoring function for non-grouped topk: {scoring_func}" ) + # [balance exp] env-gated: force perfectly balanced round-robin routing + # over routed experts (ignore accuracy) to isolate load-imbalance from + # launch/sync overhead in the EP dispatch/combine path. + import os as _os + + if _os.environ.get("MORI_FORCE_BALANCED_ROUTING", "0") == "1": + _nt, _k = topk_ids.shape + _ne = router_logits.shape[-1] + topk_ids = ( + (torch.arange(_nt * _k, device=topk_ids.device) % _ne) + .reshape(_nt, _k) + .to(topk_ids.dtype) + ) + return topk_weights, topk_ids def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): From c5336abd308932a941c940b51b1b17113c8fe932 Mon Sep 17 00:00:00 2001 From: JiaoliangYu Date: Thu, 18 Jun 2026 10:48:29 +0800 Subject: [PATCH 09/11] exp: fp8 cast-dispatch (MORI_FP8_DISPATCH_CAST) - measure dispatch kernel, cast back to bf16 for GEMM --- atom/model_ops/fused_moe/mori_prepare_finalize.py | 12 ++++++++++++ atom/model_ops/moe.py | 9 +++++++++ 2 files changed, 21 insertions(+) diff --git a/atom/model_ops/fused_moe/mori_prepare_finalize.py b/atom/model_ops/fused_moe/mori_prepare_finalize.py index a9e0238523..4f5b94019c 100644 --- a/atom/model_ops/fused_moe/mori_prepare_finalize.py +++ b/atom/model_ops/fused_moe/mori_prepare_finalize.py @@ -113,6 +113,7 @@ def __init__( is_async: bool = False, tbo_mori_ops: list | None = None, low_latency: bool = False, + fp8_cast_dispatch: bool = False, ): if not MORI_AVAILABLE: raise ImportError( @@ -129,6 +130,7 @@ def __init__( self.quant_dtype = quant_dtype self._is_async = is_async self._low_latency = low_latency + self.fp8_cast_dispatch = fp8_cast_dispatch @property def activation_format(self) -> mk.FusedMoEActivationFormat: @@ -197,6 +199,11 @@ def prepare( a1, scale = quant_func( a1, quant_dtype=dtypes.fp8, scale_type=dtypes.fp8_e8m0 ) + elif self.fp8_cast_dispatch: + # [fp8-cast exp] direct bf16->fp8 cast (no scale) so dispatch runs + # the fp8 kernel (half bytes); cast back to bf16 after dispatch so + # the GEMM self-quantizes as baseline. Measures dispatch speedup. + a1 = a1.to(dtypes.fp8) block_num, warp_per_block = self._get_dispatch_config() @@ -210,6 +217,11 @@ def prepare( a1, topk_weights, scale, topk_ids, block_num, warp_per_block ) + if self.fp8_cast_dispatch: + # dispatch_a1 is fp8 (mori out uses input.dtype); upcast for GEMM. + dispatch_a1 = dispatch_a1.to(torch.bfloat16) + dispatch_scale = None + expert_tokens_meta = mk.ExpertTokensMetadata( expert_num_tokens=dispatch_recv_token_num, expert_num_tokens_cpu=None ) diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index 19095fe279..877fc9135c 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -415,6 +415,14 @@ def _maybe_make_prepare_finalize( if _fp8_dispatch_exp: scale_dim = moe.hidden_dim // 32 + # [fp8-cast exp] measurement-only: cast a1 bf16->fp8 before dispatch + # (no scale), dispatch via fp8 kernel, cast back to bf16 after so the + # GEMM self-quantizes as usual (no scale plumbing, no crash). Lets us + # measure the real fp8 dispatch kernel speedup. Accuracy is garbage. + _fp8_cast_dispatch = ( + _os.environ.get("MORI_FP8_DISPATCH_CAST", "0") == "1" + ) + # For FP8: enable FP8 dispatch in Mori (quantize before communication) # Note: per_Tensor quant doesn't support num_local_tokens, so we use per_Token use_fp8_dispatch = is_fp8 @@ -509,6 +517,7 @@ def _maybe_make_prepare_finalize( is_async=is_async, tbo_mori_ops=tbo_mori_ops, low_latency=low_latency, + fp8_cast_dispatch=_fp8_cast_dispatch, ) return prepare_finalize From 666366c6426cd1475df43cb5e0c91bc264dffabc Mon Sep 17 00:00:00 2001 From: JiaoliangYu Date: Thu, 18 Jun 2026 11:44:53 +0800 Subject: [PATCH 10/11] exp: gate fp8 dispatch on env only (DSv4 quant_dtype=None so is_fp8 False) --- atom/model_ops/moe.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index 877fc9135c..7aa9ab2979 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -409,9 +409,10 @@ def _maybe_make_prepare_finalize( # Gated on is_fp8 so non-FP8 MORI handle config stays untouched. import os as _os - _fp8_dispatch_exp = ( - _os.environ.get("MORI_FP8_DISPATCH", "0") == "1" and is_fp8 - ) + # NOTE: drop the is_fp8 guard. DSv4 MXFP8 activation quant is owned by + # the aiter GEMM, so quant_config.quant_dtype is None here (is_fp8=False). + # We know DSv4 is per_1x32 MXFP8, so gate on the env flag alone. + _fp8_dispatch_exp = _os.environ.get("MORI_FP8_DISPATCH", "0") == "1" if _fp8_dispatch_exp: scale_dim = moe.hidden_dim // 32 From 4d302b81e9e0364363127f61800e7bddbb148c51 Mon Sep 17 00:00:00 2001 From: JiaoliangYu Date: Thu, 18 Jun 2026 13:19:47 +0800 Subject: [PATCH 11/11] exp: adapt async/TBO fp8 dispatch to MXFP8 per_1x32+e8m0 (EP+TBO support) --- .../fused_moe/mori_prepare_finalize.py | 101 +++++++--- atom/model_ops/moe.py | 188 ++++++++++-------- 2 files changed, 171 insertions(+), 118 deletions(-) diff --git a/atom/model_ops/fused_moe/mori_prepare_finalize.py b/atom/model_ops/fused_moe/mori_prepare_finalize.py index 4f5b94019c..a0fc889a0b 100644 --- a/atom/model_ops/fused_moe/mori_prepare_finalize.py +++ b/atom/model_ops/fused_moe/mori_prepare_finalize.py @@ -40,6 +40,8 @@ def init_mori_op( max_token_type_size: int, low_latency: bool = False, instance_id: int = 0, + scale_type_size: int = torch.float32.itemsize, + quant_type: str = "none", ) -> Any: """ Create a mori op instance. @@ -76,7 +78,7 @@ def init_mori_op( data_type=data_type, hidden_dim=hidden_dim, scale_dim=scale_dim, - scale_type_size=torch.float32.itemsize, + scale_type_size=scale_type_size, max_token_type_size=max_token_type_size, max_num_inp_token_per_rank=max_num_inp_token_per_rank, num_experts_per_rank=num_local_experts, @@ -86,6 +88,7 @@ def init_mori_op( kernel_type=kernel_type, gpu_per_node=gpu_per_node, rdma_block_num=rdma_block_num, + quant_type=quant_type, **({"num_qp_per_pe": 2} if low_latency else {}), ) mori_op = mori.ops.EpDispatchCombineOp(mori_config) @@ -108,12 +111,14 @@ def __init__( max_tokens_per_rank: int, num_dispatchers: int, use_fp8_dispatch: bool = False, + dispatch_quant_dtype: torch.dtype | None = None, quant_type=None, quant_dtype: torch.dtype = None, is_async: bool = False, tbo_mori_ops: list | None = None, low_latency: bool = False, - fp8_cast_dispatch: bool = False, + fp8_dispatch_decode_only: bool = False, + dispatch_quant_decode_only: bool = False, ): if not MORI_AVAILABLE: raise ImportError( @@ -125,12 +130,22 @@ def __init__( self._tbo_mori_ops = tbo_mori_ops # per-ubatch ops for TBO (IntraNode) self.num_dispatchers_ = num_dispatchers self.max_tokens_per_rank = max_tokens_per_rank - self.use_fp8_dispatch = use_fp8_dispatch + # dispatch quantization target. We keep `use_fp8_dispatch` for backward + # compatibility and map it to the generalized quantized-dispatch path. + if dispatch_quant_dtype is None and use_fp8_dispatch: + dispatch_quant_dtype = dtypes.fp8 + self.dispatch_quant_dtype = dispatch_quant_dtype self.quant_type = quant_type self.quant_dtype = quant_dtype self._is_async = is_async self._low_latency = low_latency - self.fp8_cast_dispatch = fp8_cast_dispatch + # Phase gating for quantized dispatch (fp8/fp4). Keep the legacy + # fp8-only flag for compatibility. + self.dispatch_quant_decode_only = ( + dispatch_quant_decode_only or fp8_dispatch_decode_only + ) + self.use_fp8_dispatch = self.dispatch_quant_dtype == dtypes.fp8 + self.fp8_dispatch_decode_only = self.dispatch_quant_decode_only @property def activation_format(self) -> mk.FusedMoEActivationFormat: @@ -162,6 +177,40 @@ def _get_dispatch_config(self): return 128, 16 return 64, 4 + def _dispatch_quant_dtype_now(self) -> torch.dtype | None: + """Dispatch quant dtype decision for the current forward (phase-aware).""" + if self.dispatch_quant_dtype is None: + return None + if self.dispatch_quant_decode_only and get_forward_context().context.is_prefill: + return None + return self.dispatch_quant_dtype + + def _quantize_dispatch_input( + self, a1: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor | None]: + dispatch_quant_dtype = self._dispatch_quant_dtype_now() + if dispatch_quant_dtype is None: + return a1, None + + from aiter import get_hip_quant + + quant_func = get_hip_quant(QuantType.per_1x32) + if dispatch_quant_dtype == dtypes.fp8: + # MXFP8 per_1x32 + e8m0 scale. + return quant_func( + a1, + quant_dtype=dtypes.fp8, + scale_type=dtypes.fp8_e8m0, + ) + if dispatch_quant_dtype == dtypes.fp4x2: + # MXFP4 per_1x32 + e8m0 scale. + return quant_func( + a1, + quant_dtype=dtypes.fp4x2, + scale_type=dtypes.fp8_e8m0, + ) + raise ValueError(f"Unsupported dispatch quant dtype: {dispatch_quant_dtype}") + # ---- Synchronous (non-TBO) path ---- def prepare( @@ -188,22 +237,7 @@ def prepare( assert ( not apply_router_weight_on_input ), "mori does not support apply_router_weight_on_input=True now." - scale = None - if self.use_fp8_dispatch: - from aiter import get_hip_quant - - # [fp8-dispatch exp] MXFP8 per_1x32 + e8m0 byte scale, matching - # DSv4 expert GEMM (q_type=per_1x32). dispatch_scale is wired as - # the GEMM a1_scale in modular_kernel (a1_scale=dispatch_scale). - quant_func = get_hip_quant(QuantType.per_1x32) - a1, scale = quant_func( - a1, quant_dtype=dtypes.fp8, scale_type=dtypes.fp8_e8m0 - ) - elif self.fp8_cast_dispatch: - # [fp8-cast exp] direct bf16->fp8 cast (no scale) so dispatch runs - # the fp8 kernel (half bytes); cast back to bf16 after dispatch so - # the GEMM self-quantizes as baseline. Measures dispatch speedup. - a1 = a1.to(dtypes.fp8) + a1, scale = self._quantize_dispatch_input(a1) block_num, warp_per_block = self._get_dispatch_config() @@ -217,11 +251,6 @@ def prepare( a1, topk_weights, scale, topk_ids, block_num, warp_per_block ) - if self.fp8_cast_dispatch: - # dispatch_a1 is fp8 (mori out uses input.dtype); upcast for GEMM. - dispatch_a1 = dispatch_a1.to(torch.bfloat16) - dispatch_scale = None - expert_tokens_meta = mk.ExpertTokensMetadata( expert_num_tokens=dispatch_recv_token_num, expert_num_tokens_cpu=None ) @@ -271,19 +300,25 @@ def prepare_async( ), "mori does not support apply_router_weight_on_input=True now." scale = None - if self.use_fp8_dispatch: - from aiter import get_hip_quant - + dispatch_quant_dtype = self._dispatch_quant_dtype_now() + if dispatch_quant_dtype is not None: num_tokens = a1.shape[0] if num_tokens > 0: - quant_func = get_hip_quant(QuantType.per_1x128) - a1, scale = quant_func(a1, quant_dtype=dtypes.fp8) + a1, scale = self._quantize_dispatch_input(a1) else: hidden_size = a1.shape[1] if a1.dim() > 1 else 0 - a1 = torch.empty(a1.shape, dtype=dtypes.fp8, device=a1.device) + if dispatch_quant_dtype == dtypes.fp4x2: + packed_hidden_size = hidden_size // 2 + a1 = torch.empty( + (0, packed_hidden_size), + dtype=dtypes.fp4x2, + device=a1.device, + ) + else: + a1 = torch.empty(a1.shape, dtype=dispatch_quant_dtype, device=a1.device) scale = torch.empty( - (0, hidden_size // 128), - dtype=torch.float32, + (0, hidden_size // 32), + dtype=dtypes.fp8_e8m0, device=a1.device, ) diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index 7aa9ab2979..b8e33c5319 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -319,6 +319,25 @@ def dp_gather_hidden_and_router( return hidden_states, router_logits, original_hidden_size, None +def run_dp_collective_with_tbo_overlap(fn): + """Run a DP collective, yielding to TBO comm stream when overlap is active.""" + from atom.utils.tbo.ubatching import tbo_active + + if not tbo_active(): + return fn() + + from atom.utils.tbo.ubatching import ( + tbo_switch_to_compute_sync, + tbo_yield_and_switch_from_compute_to_comm, + ) + + tbo_yield_and_switch_from_compute_to_comm() + try: + return fn() + finally: + tbo_switch_to_compute_sync() + + @torch_compile_guard() def get_max_tokens_across_dispatchers(input: torch.Tensor) -> int: return input.item() @@ -392,74 +411,90 @@ def _maybe_make_prepare_finalize( # For 1x128 quant, the scale dim for each token is hidden_dim // 128 scale_dim = 1 if quant_config.is_per_act_token else moe.hidden_dim // 128 - # Check if quant_dtype is an FP8 type from aiter import QuantType - fp8_dtypes = ( - torch.float8_e4m3fn, - torch.float8_e4m3fnuz, - torch.float8_e5m2, - torch.float8_e5m2fnuz, - ) - is_fp8 = quant_config.quant_dtype in fp8_dtypes - - # [fp8-dispatch exp] env-gated MXFP8 (per_1x32) dispatch, FP8 models - # only. DSv4 expert GEMM uses q_type=per_1x32 (block-32, e8m0 scale), - # so dispatch quantizes a1 to MXFP8 with hidden//32 scale groups. - # Gated on is_fp8 so non-FP8 MORI handle config stays untouched. import os as _os - # NOTE: drop the is_fp8 guard. DSv4 MXFP8 activation quant is owned by - # the aiter GEMM, so quant_config.quant_dtype is None here (is_fp8=False). - # We know DSv4 is per_1x32 MXFP8, so gate on the env flag alone. - _fp8_dispatch_exp = _os.environ.get("MORI_FP8_DISPATCH", "0") == "1" - if _fp8_dispatch_exp: - scale_dim = moe.hidden_dim // 32 + dispatch_quant_dtype: torch.dtype | None = None + dispatch_dtype_env = _os.environ.get("MORI_DISPATCH_DTYPE", "auto").strip().lower() + if dispatch_dtype_env in ("auto", ""): + model_dispatch_quant = MoEActivationQuant.from_model_config( + moe.a_quant_dtype if isinstance(moe.a_quant_dtype, str) else None + ) + if model_dispatch_quant == MoEActivationQuant.FP8: + dispatch_quant_dtype = dtypes.fp8 + elif model_dispatch_quant == MoEActivationQuant.FP4: + dispatch_quant_dtype = dtypes.fp4x2 + elif dispatch_dtype_env in ("none", "off", "bf16"): + dispatch_quant_dtype = None + elif dispatch_dtype_env == "fp8": + dispatch_quant_dtype = dtypes.fp8 + elif dispatch_dtype_env == "fp4": + dispatch_quant_dtype = dtypes.fp4x2 + else: + logger.warning( + "Invalid MORI_DISPATCH_DTYPE=%s. Supported: auto|none|bf16|fp8|fp4. Falling back to auto.", + dispatch_dtype_env, + ) + model_dispatch_quant = MoEActivationQuant.from_model_config( + moe.a_quant_dtype if isinstance(moe.a_quant_dtype, str) else None + ) + if model_dispatch_quant == MoEActivationQuant.FP8: + dispatch_quant_dtype = dtypes.fp8 + elif model_dispatch_quant == MoEActivationQuant.FP4: + dispatch_quant_dtype = dtypes.fp4x2 + + # Backward compatibility for the old fp8 env knob. + if _os.environ.get("MORI_FP8_DISPATCH", "0") == "1": + dispatch_quant_dtype = dtypes.fp8 + + if dispatch_quant_dtype is not None and moe.hidden_dim % 32 != 0: + logger.warning( + "Disable quantized MORI dispatch because hidden_dim=%d is not divisible by 32.", + moe.hidden_dim, + ) + dispatch_quant_dtype = None - # [fp8-cast exp] measurement-only: cast a1 bf16->fp8 before dispatch - # (no scale), dispatch via fp8 kernel, cast back to bf16 after so the - # GEMM self-quantizes as usual (no scale plumbing, no crash). Lets us - # measure the real fp8 dispatch kernel speedup. Accuracy is garbage. - _fp8_cast_dispatch = ( - _os.environ.get("MORI_FP8_DISPATCH_CAST", "0") == "1" - ) - - # For FP8: enable FP8 dispatch in Mori (quantize before communication) - # Note: per_Tensor quant doesn't support num_local_tokens, so we use per_Token - use_fp8_dispatch = is_fp8 - quant_type = None - if use_fp8_dispatch: - if quant_config.is_block_quantized: - quant_type = QuantType.per_1x128 - elif quant_config.is_per_act_token: - quant_type = QuantType.per_Token - - # For FP8: use FP8 dtype for communication - # For FP4/no quant: use bfloat16 - # mori_dtype = ( - # quant_config.quant_dtype - # if is_fp8 and quant_type is not None - # else torch.bfloat16 - # ) - # mori_dtype = torch.bfloat16 + # phase gating for quantized dispatch. Keep the legacy fp8-only env. + _dispatch_quant_decode_only = ( + _os.environ.get( + "MORI_DISPATCH_DECODE_ONLY", + _os.environ.get("MORI_FP8_DISPATCH_DECODE_ONLY", "0"), + ) + == "1" + ) + + if dispatch_quant_dtype is not None: + scale_dim = moe.hidden_dim // 32 + scale_type_size = 1 # e8m0 scale + else: + scale_type_size = torch.float32.itemsize + + # Combine quantization mode is provided by the standardized MORI + # dispatch/combine API. + combine_quant_type = _os.environ.get( + "MORI_COMBINE_QUANT_TYPE", "none" + ).strip().lower() + if combine_quant_type not in {"none", "fp8_direct_cast", "fp8_blockwise"}: + logger.warning( + "Invalid MORI_COMBINE_QUANT_TYPE=%s. Supported: none|fp8_direct_cast|fp8_blockwise. Falling back to none.", + combine_quant_type, + ) + combine_quant_type = "none" all_to_all_args = dict( rank=all2all_manager.rank, num_ep_ranks=all2all_manager.world_size, - # quant_dtype=mori_dtype, - # We now use bfloat16 for mori - # TODO: To support quant quant_dtype=moe.in_dtype, token_hidden_size=moe.hidden_dim, scale_dim=scale_dim, - # [fp8-dispatch exp] e8m0 byte scale (1B) for MXFP8 dispatch - scale_type_size=(1 if _fp8_dispatch_exp else torch.float32.itemsize), + scale_type_size=scale_type_size, max_num_tokens_per_dp_rank=16384, - # input_dtype=moe.in_dtype, input_dtype=moe.in_dtype, num_local_experts=moe.num_experts // all2all_manager.world_size, num_experts_per_token=moe.experts_per_token, gpu_per_node=moe.moe_parallel_config.local_ep_size, + quant_type=combine_quant_type, ) from atom.config import get_current_atom_config from atom.utils.tbo.ubatching import tbo_enabled @@ -469,15 +504,8 @@ def _maybe_make_prepare_finalize( atom_config = get_current_atom_config() low_latency = getattr(atom_config, "enable_low_latency", False) - # [fp8-dispatch exp] env-gated MXFP8 dispatch (DSv4 q_type=per_1x32). - # Default off -> baseline bf16 dispatch unchanged. _fp8_dispatch_exp - # already includes the is_fp8 check. - if _fp8_dispatch_exp: - use_fp8_dispatch = True - quant_type = QuantType.per_1x32 - else: - use_fp8_dispatch = False - quant_type = None + use_fp8_dispatch = dispatch_quant_dtype == dtypes.fp8 + quant_type = QuantType.per_1x32 if use_fp8_dispatch else None common_args = dict( rank=all2all_manager.rank, @@ -490,6 +518,8 @@ def _maybe_make_prepare_finalize( gpu_per_node=moe.moe_parallel_config.local_ep_size, data_type_itemsize=moe.in_dtype.itemsize, max_token_type_size=moe.in_dtype.itemsize, + scale_type_size=scale_type_size, + quant_type=combine_quant_type, ) tbo_mori_ops = None @@ -514,11 +544,13 @@ def _maybe_make_prepare_finalize( max_tokens_per_rank=moe.max_num_tokens, num_dispatchers=all2all_manager.world_size, use_fp8_dispatch=use_fp8_dispatch, + dispatch_quant_dtype=dispatch_quant_dtype, quant_type=quant_type, is_async=is_async, tbo_mori_ops=tbo_mori_ops, low_latency=low_latency, - fp8_cast_dispatch=_fp8_cast_dispatch, + fp8_dispatch_decode_only=_dispatch_quant_decode_only, + dispatch_quant_decode_only=_dispatch_quant_decode_only, ) return prepare_finalize @@ -3308,29 +3340,17 @@ def forward_impl_graph( dp_group = get_dp_group() dp_eager_mode = not ctx.context.dp_uniform_decode - from atom.utils.tbo.ubatching import tbo_active - - _tbo = tbo_active() - if _tbo: - from atom.utils.tbo.ubatching import ( - tbo_switch_to_compute_sync, - tbo_yield_and_switch_from_compute_to_comm, - ) - - tbo_yield_and_switch_from_compute_to_comm() - ( hidden_states, router_logits, original_hidden_size, sizes, - ) = dp_gather_hidden_and_router( - hidden_states, router_logits, dp_eager_mode, ctx, dp_group + ) = run_dp_collective_with_tbo_overlap( + lambda: dp_gather_hidden_and_router( + hidden_states, router_logits, dp_eager_mode, ctx, dp_group + ) ) - if _tbo: - tbo_switch_to_compute_sync() - # Matrix multiply. final_hidden_states = self.quant_method.apply( layer=self, @@ -3353,18 +3373,16 @@ def forward_impl_graph( # Use reduce_scatter when DP > 1 but not using mori all2all kernels if use_dp_gather_scatter: - if _tbo: - tbo_yield_and_switch_from_compute_to_comm() if dp_eager_mode: - final_hidden_states = reduce_scatterv( - final_hidden_states, sizes, dp_group + final_hidden_states = run_dp_collective_with_tbo_overlap( + lambda: reduce_scatterv(final_hidden_states, sizes, dp_group) ) else: - final_hidden_states = reduce_scatter_with_unpadding( - final_hidden_states, original_hidden_size + final_hidden_states = run_dp_collective_with_tbo_overlap( + lambda: reduce_scatter_with_unpadding( + final_hidden_states, original_hidden_size + ) ) - if _tbo: - tbo_switch_to_compute_sync() if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.)