[Hardware][AMD][Kernel] mori all2all backend integration #26013

@@ -41,6 +41,7 @@ th {

| flashinfer<sup>4</sup> | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferCutlassMoEPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize.FlashInferCutlassMoEPrepareAndFinalize] |

Collaborator: Please remove this line (duplicate of prior line).

Author: This is not part of this PR, and it's not related.

| MoEPrepareAndFinalizeNoEP<sup>5</sup> | standard | fp8,int8 | G,A,T | N | Y | [`MoEPrepareAndFinalizeNoEP`][vllm.model_executor.layers.fused_moe.prepare_finalize.MoEPrepareAndFinalizeNoEP] |
| BatchedPrepareAndFinalize<sup>5</sup> | batched | fp8,int8 | G,A,T | N | Y | [`BatchedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedPrepareAndFinalize] |
| MoriPrepareAndFinalize<sup>7</sup> | standard | fp8<sup>8</sup> | G(128),A,T<sup>8</sup> | N | Y | [`MoriPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.mori_prepare_finalize.MoriPrepareAndFinalize] |

!!! info "Table key"
    1. All types: mxfp4, nvfp4, int4, int8, fp8

@@ -49,6 +50,8 @@ th {

    4. Controlled by different env vars (`VLLM_FLASHINFER_MOE_BACKEND` "throughput" or "latency")
    5. This is a no-op dispatcher that can be used to pair with any modular experts to produce a modular kernel that runs w/o dispatch or combine. These cannot be selected via environment variable. These are generally use for testing or adapting an expert subclass to the `fused_experts` API.
    6. This depends on the experts implementation.
    7. Currently, MoRI supports low-latency mode only.
    8. This depends on the experts implementation; currently, mori supports aiter.

Collaborator: Is this to explain or a direct answer to …

Author: Yes, we integrated mori on the AITER MoE only.

---

@@ -117,4 +120,5 @@ The following table shows "families" of modular kernels that are intended to wor

|----------------------------------|------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------|
| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8`, </br>`MarlinExperts` |
| deepep_low_latency,</br>pplx | `DeepEPLLPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`BatchedTritonOrDeepGemmExperts`,</br>`CutlassBatchedExpertsFp8`,</br>`BatchedMarlinExperts`|
| flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` |
| flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` |
| mori | `MoriPrepareAndFinalize` | `AiterMoriExperts` |
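
Below is a minimal launch sketch for trying the new backend. It assumes mori is selected like the other all2all backends, via `VLLM_ALL2ALL_BACKEND`, with a value of `mori`; neither detail is visible in this excerpt, so treat both as assumptions. `VLLM_MORI_CONFIG_PATH` does appear in the code changes below.

```python
# Hypothetical launch setup for the mori all2all backend. The selector name
# and value are assumptions; only VLLM_MORI_CONFIG_PATH comes from this PR.
import os

os.environ["VLLM_ALL2ALL_BACKEND"] = "mori"  # assumed selector value
os.environ["VLLM_MORI_CONFIG_PATH"] = "/tmp/mori_config.json"  # optional tuning file

# With the environment set, start vLLM as usual, e.g.:
#   vllm serve <model> --tensor-parallel-size 8 --enable-expert-parallel
```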

@@ -1,7 +1,11 @@

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import os
from pathlib import Path
from typing import Any

import psutil
import torch
import torch.distributed as dist
```

@@ -10,7 +14,7 @@

```python
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils.flashinfer import has_flashinfer_all2all
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx

from .base_device_communicator import All2AllManagerBase, Cache
```

@@ -488,3 +492,342 @@ def cleanup(self):

```python
        self.prepare_workspace_tensor = None
        self.mapping = None
        self.initialized = False


class MoriAll2AllManager(All2AllManagerBase):
    """
    All2All communication based on mori kernels.
    """

    def __init__(self, cpu_group):
        assert has_mori(), "Please install mori from ROCm/mori github."

        super().__init__(cpu_group)
        self.handle_cache = Cache()
        self.config = None
        self._shmem_initialized = False

        self.json_config = None
        config_path = envs.VLLM_MORI_CONFIG_PATH
        if config_path:
            self.json_config = self._load_mori_config_from_json(config_path)

        # Delay mori shmem initialization until first use
        logger.debug("[rank %s] MoriAll2AllManager created", self.rank)

    def _ensure_shmem_initialized(self):
        """Initialize mori's shared memory system lazily"""
        if self._shmem_initialized:
            return

        import torch.distributed as dist
        from mori.shmem import shmem_torch_process_group_init

        try:
            # Check if we have a valid backend
            backend = dist.get_backend()
            if backend is None:
                raise RuntimeError("No valid distributed backend found")

            logger.debug(
                "[rank %s] PyTorch distributed ready with backend: %s",
                self.rank,
                backend,
            )

            assert self.cpu_group is not None, "No CPU group is given to mori"
            ppid = psutil.Process(os.getpid()).ppid()
            group_name = f"mori_shmem_group_{ppid}"

            try:
                import torch._C._distributed_c10d as c10d

                # Register the process group
                c10d._register_process_group(group_name, self.cpu_group)
                logger.debug(
                    "[rank %s] Registered proc group %s", self.rank, group_name
                )

                # Initialize mori shmem with the registered group
                shmem_torch_process_group_init(group_name)
                logger.debug("[rank %s] torch proc group shmem init success", self.rank)
                self._shmem_initialized = True
                return

            except Exception as torch_error:
                raise RuntimeError(
                    "torch process group initialization failed"
                ) from torch_error

        except Exception as e:
            raise RuntimeError("mori shmem initialization failed") from e
```
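
The shmem bootstrap above boils down to: make sure `torch.distributed` is up, register the CPU group under a ppid-derived name, then hand that name to mori. Below is a minimal stand-alone sketch of the same pattern on a 1-rank gloo group; the port and group naming are illustrative, and the mori call is guarded so the script still runs where mori is not installed.

```python
# Stand-alone sketch of the registration pattern used by
# MoriAll2AllManager._ensure_shmem_initialized (1-rank gloo group, no GPU).
import os

import psutil
import torch.distributed as dist


def init_mori_shmem_group() -> str:
    # Bring up torch.distributed first; mori piggybacks on an existing group.
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29511",  # illustrative port
        world_size=1,
        rank=0,
    )
    cpu_group = dist.new_group(ranks=[0], backend="gloo")

    # Same naming scheme as the PR: one shmem group per launcher process.
    group_name = f"mori_shmem_group_{psutil.Process(os.getpid()).ppid()}"

    import torch._C._distributed_c10d as c10d

    c10d._register_process_group(group_name, cpu_group)

    try:
        from mori.shmem import shmem_torch_process_group_init

        shmem_torch_process_group_init(group_name)  # needs mori + ROCm
    except ImportError:
        print("mori not installed; stopping after process-group registration")
    return group_name


if __name__ == "__main__":
    print("registered:", init_mori_shmem_group())
```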

```python
    def _load_mori_config_from_json(self, json_path: str) -> dict | None:
        """
        Load mori configuration parameters from JSON file.

        Supports both flat and hierarchical schema:

        Flat schema:
        {
            "warp_num_per_block": 8,
            "block_num": 80,
        }

        Hierarchical schema (dispatch/combine specific):
        {
            "global": {
                "warp_num_per_block": 8,
                "block_num": 80,
            },
            "dispatch": {
                "warp_num_per_block": 16,
                "block_num": 160
            },
            "combine": {
                "warp_num_per_block": 4,
                "block_num": 40
            }
        }

        Args:
            json_path: Path to JSON configuration file

        Returns:
            Dictionary of configuration parameters, or None if file doesn't exist

        Raises:
            ValueError: If JSON is invalid or contains unsupported parameters
        """
        if not json_path:
            return None

        json_file = Path(json_path)
        if not json_file.exists():
            logger.warning(
                "[rank %d] Mori config file not found: %s", self.rank, json_path
            )
            return None

        try:
            with open(json_file) as f:
                config = json.load(f)

            # Valid parameter keys
            valid_param_keys = {
                "warp_num_per_block",
                "block_num",
            }

            is_hierarchical = any(
                key in config for key in ["global", "dispatch", "combine"]
            )

            if is_hierarchical:
                valid_top_keys = {"global", "dispatch", "combine"}
                invalid_keys = set(config.keys()) - valid_top_keys
                if invalid_keys:
                    raise ValueError(
                        f"Invalid top-level keys: {invalid_keys}. "
                        f"Valid keys: {valid_top_keys}"
                    )

                # Validate each section
                for section in ["global", "dispatch", "combine"]:
                    if section in config:
                        section_config = config[section]
                        if not isinstance(section_config, dict):
                            raise ValueError(f"'{section}' must be a dictionary")

                        invalid_keys = set(section_config.keys()) - valid_param_keys
                        if invalid_keys:
                            raise ValueError(
                                f"Invalid keys in '{section}': {invalid_keys}. "
                                f"Valid keys: {valid_param_keys}"
                            )
            else:
                invalid_keys = set(config.keys()) - valid_param_keys
                if invalid_keys:
                    raise ValueError(
                        f"Invalid config keys: {invalid_keys}. "
                        f"Valid keys: {valid_param_keys}"
                    )

            return config

        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON in mori config file {json_path}") from e
        except Exception as e:
            raise ValueError(f"Error loading mori config from {json_path}") from e
```
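
As a concrete illustration of the schema accepted above, the sketch below writes a hierarchical config file and points `VLLM_MORI_CONFIG_PATH` at it. The path and numbers are made up; only the keys (`warp_num_per_block`, `block_num`, and the `global`/`dispatch`/`combine` sections) come from the validation logic in `_load_mori_config_from_json`.

```python
# Sketch: writing a hierarchical mori tuning config and pointing
# VLLM_MORI_CONFIG_PATH at it before launching vLLM. Values are illustrative.
import json
import os
from pathlib import Path

mori_cfg = {
    # Defaults used when a phase-specific section is absent.
    "global": {"warp_num_per_block": 8, "block_num": 80},
    # Phase-specific overrides (validated above; how each phase consumes them
    # is decided by code outside this hunk).
    "dispatch": {"warp_num_per_block": 16, "block_num": 160},
    "combine": {"warp_num_per_block": 4, "block_num": 40},
}

cfg_path = Path("/tmp/mori_config.json")
cfg_path.write_text(json.dumps(mori_cfg, indent=2))
os.environ["VLLM_MORI_CONFIG_PATH"] = str(cfg_path)
```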

```python
    def _make_mori_config(
        self,
        max_num_tokens: int,
        num_local_experts: int,
        experts_per_token: int,
        hidden_dim: int,
        scale_dim: int,
        scale_type_size: int,
        data_type: torch.dtype = torch.bfloat16,
        quant_dtype: torch.dtype | None = None,
    ):
        """
        Create mori EpDispatchCombineConfig.

        Args:
            max_num_tokens: Maximum number of tokens per DP rank
            num_local_experts: Number of local experts
            experts_per_token: Number of experts per token (topk)
            hidden_dim: Hidden dimension size
            scale_dim: Scale dimension for quantization
            scale_type_size: Scale type size for quantization
            data_type: Tensor data type
            quant_dtype: Quantization data type (optional)
        """
        from mori.ops import EpDispatchCombineConfig
        from mori.ops.dispatch_combine import EpDispatchCombineKernelType
```

Collaborator: Adding a check for data_type and quant_dtype before proceeding?

Author: A quant_dtype check is added.

```python
        from vllm.platforms import current_platform

        assert quant_dtype is None or quant_dtype == current_platform.fp8_dtype()

        # Default values (can be overridden by JSON)
        warp_num_per_block = 8
        block_num = 80

        # Override with JSON config if provided
        if self.json_config is not None:
            is_hierarchical = any(
                key in self.json_config for key in ["global", "dispatch", "combine"]
            )

            global_config = self.json_config
            if is_hierarchical and "global" in global_config:
                global_config = self.json_config["global"]

            warp_num_per_block = global_config.get(
                "warp_num_per_block", warp_num_per_block
            )
            block_num = global_config.get("block_num", block_num)

        config = EpDispatchCombineConfig(
            data_type=data_type if quant_dtype is None else quant_dtype,
            rank=self.rank,
            world_size=self.world_size,
            hidden_dim=hidden_dim,
            max_num_inp_token_per_rank=max_num_tokens,
            num_experts_per_rank=num_local_experts,
            num_experts_per_token=experts_per_token,
            max_token_type_size=data_type.itemsize,
            # Performance tuning parameters
            warp_num_per_block=warp_num_per_block,
            block_num=block_num,
            # Quantization support
            scale_dim=scale_dim,
            scale_type_size=scale_type_size,
            # Determine kernel type based on topology
            kernel_type=(
                EpDispatchCombineKernelType.InterNode
                if self.internode
                else EpDispatchCombineKernelType.IntraNode
            ),
        )

        return config
```
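
Note that the override resolution above reads only the flat top level or the `global` section of the JSON config; the `dispatch`/`combine` sections are validated here but presumably consumed elsewhere, outside this hunk. The small stand-alone sketch below replays that precedence so it can be checked without mori installed.

```python
# Minimal replay of the tuning-parameter resolution in _make_mori_config:
# hard-coded defaults, overridden by a flat config or its "global" section.
def resolve_tuning(json_config: dict | None) -> tuple[int, int]:
    warp_num_per_block, block_num = 8, 80  # defaults

    if json_config is not None:
        is_hierarchical = any(
            key in json_config for key in ["global", "dispatch", "combine"]
        )
        global_config = json_config
        if is_hierarchical and "global" in global_config:
            global_config = json_config["global"]

        warp_num_per_block = global_config.get("warp_num_per_block", warp_num_per_block)
        block_num = global_config.get("block_num", block_num)

    return warp_num_per_block, block_num


assert resolve_tuning(None) == (8, 80)
assert resolve_tuning({"block_num": 120}) == (8, 120)
assert resolve_tuning({"global": {"warp_num_per_block": 16}, "dispatch": {}}) == (16, 80)
```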

```python
    def get_handle(self, kwargs):
        """
        Get or create mori operation handle.

        Args:
            kwargs: Dictionary with keys:
                - max_num_tokens: Maximum tokens per DP rank
                - num_local_experts: Number of local experts
                - experts_per_token: Number of experts per token (topk)
                - hidden_dim: Hidden dimension size
                - data_type: Tensor data type (optional, default bfloat16)
                - scale_dim: Scale dimension (optional)
                - scale_type_size: Scale type size (optional)
                - ubatch_id: Microbatch ID (optional)
        """
        # Ensure shmem is initialized before creating handles
        self._ensure_shmem_initialized()

        def create_mori_handle(
            max_num_tokens: int,
            num_local_experts: int,
            experts_per_token: int,
            hidden_dim: int,
            scale_dim: int,
            scale_type_size: int,
            data_type: torch.dtype = torch.bfloat16,
            quant_dtype: torch.dtype | None = None,
        ):
            from mori.ops import EpDispatchCombineOp

            config = self._make_mori_config(
                max_num_tokens=max_num_tokens,
                num_local_experts=num_local_experts,
                experts_per_token=experts_per_token,
                hidden_dim=hidden_dim,
                scale_dim=scale_dim,
                scale_type_size=scale_type_size,
                data_type=data_type,
                quant_dtype=quant_dtype,
            )
            op = EpDispatchCombineOp(config)
            logger.debug(
                "[rank %s] Created mori handle with config: tokens=%d, experts=%d,"
                " topk=%d, hidden_dim=%d",
                self.dp_rank,
                max_num_tokens,
                num_local_experts,
                experts_per_token,
                hidden_dim,
            )
            return op

        return self.handle_cache.get_or_create(kwargs, create_mori_handle)
```
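
For orientation, the sketch below shows the shape of the kwargs dictionary a caller might hand to `get_handle`, following the docstring above. All sizes are arbitrary examples, `manager` is a hypothetical `MoriAll2AllManager` instance, and the actual call is left in a comment because it needs mori, ROCm, and an initialized process group.

```python
# Sketch of the kwargs a prepare/finalize layer might pass to get_handle().
# Shapes are illustrative; the call itself is commented out because it
# requires mori, ROCm, and an initialized distributed group.
import torch

handle_kwargs = dict(
    max_num_tokens=256,        # max tokens per DP rank
    num_local_experts=32,      # experts hosted on this rank
    experts_per_token=8,       # router top-k
    hidden_dim=7168,
    scale_dim=56,              # e.g. hidden_dim / 128 for group-128 scales (illustrative)
    scale_type_size=4,         # e.g. fp32 scale entries (illustrative)
    data_type=torch.bfloat16,
    quant_dtype=torch.float8_e4m3fn,
)

# op = manager.get_handle(handle_kwargs)  # identical kwargs hit the handle cache
```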

```python
    def dispatch(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False,
    ):
        raise NotImplementedError

    def combine(
        self,
        hidden_states: torch.Tensor,
        is_sequence_parallel: bool = False,
    ):
        raise NotImplementedError

    def destroy(self):
        """Clean up mori resources"""
        try:
            # Clear operation handle cache
            with self.handle_cache._lock:
                for _, handle in self.handle_cache._cache.items():
                    handle.destroy()

            # finalize mori shared memory if it was initialized
            if self._shmem_initialized:
                try:
                    from mori.shmem import shmem_finalize

                    # Check if shmem is actually active before finalizing
                    shmem_finalize()
                    logger.debug("[rank %s] mori shmem finalize", self.dp_rank)
                except Exception as shmem_error:
                    logger.debug(
                        "[rank %s] shmem finalize failed "
                        "(may not have been active): %s",
                        self.dp_rank,
                        shmem_error,
                    )

            logger.debug("[rank %s] mori resources cleaned up", self.dp_rank)

        except Exception as e:
            logger.warning("[rank %s] mori cleanup fail: %s", self.dp_rank, e)
```