
Commit a00b38b

applied pre-commit results and Read the Docs build results
Signed-off-by: HakJu Kim <[email protected]>
1 parent ea17a69 commit a00b38b

File tree

7 files changed: +141 -97 lines changed


vllm/distributed/device_communicators/all2all.py

Lines changed: 66 additions & 36 deletions
@@ -9,7 +9,7 @@
 from vllm.distributed import get_dp_group, get_ep_group
 from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
-from vllm.utils import has_deep_ep, has_pplx, has_mori
+from vllm.utils import has_deep_ep, has_mori, has_pplx
 from vllm.utils.flashinfer import has_flashinfer_all2all

 from .base_device_communicator import All2AllManagerBase, Cache
@@ -439,24 +439,26 @@ def cleanup(self):
         self.mapping = None
         self.initialized = False

+
 class MoriAll2AllManager(All2AllManagerBase):
     """
     All2All communication based on mori kernels.
     """
+
     def __init__(self, cpu_group):
         assert has_mori(
-        ), "mori not found. Please follow https://github.com/ROCm/mori/blob/main/README.md#installation to install mori."  # noqa
+        ), "Please install mori from https://github.com/ROCm/mori."

         super().__init__(cpu_group)
         self.handle_cache = Cache()
         self.config = None
         self._op_handles = {}  # Cache for EpDispatchCombineOp instances
         self._shmem_initialized = False
         # Delay mori shmem initialization until first use
-        logger.debug(f"[rank {self.rank}] MoriAll2AllManager created, shmem will be initialized lazily")
+        logger.debug("[rank %s] MoriAll2AllManager created", self.rank)

     def _ensure_shmem_initialized(self):
-        """Ensure mori's shared memory system is initialized (lazy initialization)"""
+        """Initialize mori's shared memory system lazily"""
         if self._shmem_initialized:
             return

@@ -473,45 +475,60 @@ def _ensure_shmem_initialized(self):
             if backend is None:
                 raise RuntimeError("No valid distributed backend found")

-            logger.debug(f"[rank {self.rank}] PyTorch distributed ready with backend: {backend}")
+            logger.debug(
+                "[rank %s] PyTorch distributed ready with backend: %s",
+                self.rank, backend)

-            current_group = self.cpu_group if self.cpu_group is not None else dist.group.WORLD
+            current_group = (self.cpu_group if self.cpu_group is not None else
+                             dist.group.WORLD)

             # TODO(inhyeok): make group_name more reasonable
             group_name = "default"
             try:
+                import contextlib
+
                 import torch._C._distributed_c10d as c10d

                 # Try to unregister first in case it exists
-                try:
+                with contextlib.suppress(RuntimeError):
                     c10d._unregister_process_group(group_name)
-                except:
-                    pass

                 # Register the current process group
                 c10d._register_process_group(group_name, current_group)
-                logger.debug(f"[rank {self.rank}] Registered process group '{group_name}'")
+                logger.debug("[rank %s] Registered process group '%s'",
+                             self.rank, group_name)

                 # Initialize mori shmem with the registered group
                 mori.shmem.shmem_torch_process_group_init(group_name)
-                logger.debug(f"[rank {self.rank}] Torch process group shmem initialization successful")
+                logger.debug(
+                    "[rank %s] torch process group shmem init success",
+                    self.rank)
                 self._shmem_initialized = True
                 return

             except Exception as torch_error:
-                logger.debug(f"[rank {self.rank}] Torch process group shmem init failed: {torch_error}")
+                logger.debug(
+                    "[rank %s] torch process group shmem init failed: %s",
+                    self.rank, torch_error)

             self._shmem_initialized = True

         except Exception as e:
-            logger.error(f"[rank {self.rank}] mori shmem initialization failed: {e}")
+            logger.error("[rank %s] mori shmem initialization failed: %s",
+                         self.rank, e)
             # Don't fail completely - mark as initialized to avoid retry loops
             self._shmem_initialized = True
-            logger.warning(f"[rank {self.rank}] Continuing without mori shmem optimization")
-
-    def _make_mori_config(self, max_num_tokens: int, num_local_experts: int,
-                          experts_per_token: int, hidden_dim: int,
-                          scale_dim: int, scale_type_size: int,
+            logger.warning(
+                "[rank %s] Continuing without mori shmem optimization",
+                self.rank)
+
+    def _make_mori_config(self,
+                          max_num_tokens: int,
+                          num_local_experts: int,
+                          experts_per_token: int,
+                          hidden_dim: int,
+                          scale_dim: int,
+                          scale_type_size: int,
                           data_type: torch.dtype = torch.bfloat16,
                           quant_dtype: torch.dtype = None):
         """Create mori EpDispatchCombineConfig"""
@@ -546,9 +563,8 @@ def _make_mori_config(self, max_num_tokens: int, num_local_experts: int,

             # Determine kernel type based on topology
             kernel_type=(EpDispatchCombineKernelType.InterNode
-                         if self.internode
-                         else EpDispatchCombineKernelType.IntraNode)
-        )
+                         if self.internode else
+                         EpDispatchCombineKernelType.IntraNode))

         return config

@@ -578,13 +594,16 @@ def get_handle(self, kwargs):
         scale_type_size = kwargs.get('scale_type_size')

         # Validate required parameters
-        if any(param is None for param in [max_num_tokens, num_local_experts,
-                                           experts_per_token, hidden_dim]):
-            raise ValueError("Missing required parameters for mori handle creation")
+        if any(
+                param is None for param in
+            [max_num_tokens, num_local_experts, experts_per_token, hidden_dim
+             ]):
+            raise ValueError(
+                "Missing required parameters for mori handle creation")

         # Create cache key
         cache_key = (max_num_tokens, num_local_experts, experts_per_token,
-                    hidden_dim, data_type)
+                     hidden_dim, data_type)

         # Check cache first
         if cache_key in self._op_handles:
@@ -607,17 +626,22 @@ def get_handle(self, kwargs):
         # Cache the handle
         self._op_handles[cache_key] = op

-        logger.debug(f"[rank {self.dp_rank}] Created mori handle with config: "
-                     f"tokens={max_num_tokens}, experts={num_local_experts}, "
-                     f"topk={experts_per_token}, hidden={hidden_dim}")
+        logger.debug(
+            "[rank %s] Created mori handle with config: tokens=%d, experts=%d,"
+            " topk=%d, hidden_dim=%d", self.dp_rank, max_num_tokens,
+            num_local_experts, experts_per_token, hidden_dim)

         return op

-    def dispatch(self, hidden_states: torch.Tensor,
-                 router_logits: torch.Tensor):
+    def dispatch(self,
+                 hidden_states: torch.Tensor,
+                 router_logits: torch.Tensor,
+                 is_sequence_parallel: bool = False):
         raise NotImplementedError

-    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def combine(self,
+                hidden_states: torch.Tensor,
+                is_sequence_parallel: bool = False):
         raise NotImplementedError

     def destroy(self):
@@ -626,17 +650,23 @@ def destroy(self):
             # Clear operation handle cache
             self._op_handles.clear()

-            # Try to finalize mori shared memory if it was successfully initialized
+            # finalize mori shared memory if it was initialized
             if self._shmem_initialized:
                 try:
                     import mori.shmem
+
                     # Check if shmem is actually active before finalizing
                     mori.shmem.shmem_finalize()
-                    logger.debug(f"[rank {self.dp_rank}] mori shmem finalized")
+                    logger.debug("[rank %s] mori shmem finalized",
+                                 self.dp_rank)
                 except Exception as shmem_error:
-                    logger.debug(f"[rank {self.dp_rank}] shmem finalize failed (may not have been active): {shmem_error}")
+                    logger.debug(
+                        "[rank %s] shmem finalize failed "
+                        "(may not have been active): %s", self.dp_rank,
+                        shmem_error)

-            logger.debug(f"[rank {self.dp_rank}] mori resources cleaned up")
+            logger.debug("[rank %s] mori resources cleaned up", self.dp_rank)

         except Exception as e:
-            logger.warning(f"[rank {self.dp_rank}] Error during mori cleanup: {e}")
+            logger.warning("[rank %s] Error during mori cleanup: %s",
+                           self.dp_rank, e)
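Most of the edits in this file replace f-string log calls with lazy %-style formatting and swap a bare try/except/pass for contextlib.suppress, which is what the pre-commit hooks push toward. Below is a minimal standalone sketch of those two patterns; the logger setup, rank value, and registry dict are illustrative only, not vLLM code.

import contextlib
import logging

logger = logging.getLogger(__name__)
rank = 0  # illustrative rank value

# Lazy %-style logging: the message is interpolated only if the record is
# actually emitted, whereas an f-string is always evaluated eagerly.
logger.debug("[rank %s] MoriAll2AllManager created", rank)

# contextlib.suppress replaces a bare try/except/pass when an earlier
# registration may or may not exist.
registry = {"default": object()}
with contextlib.suppress(KeyError):
    del registry["default"]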

vllm/distributed/device_communicators/base_device_communicator.py

Lines changed: 2 additions & 0 deletions
@@ -7,7 +7,9 @@
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
+
 from vllm.logger import init_logger
+
 logger = init_logger(__name__)


vllm/model_executor/layers/fused_moe/aiter_experts.py

Lines changed: 14 additions & 12 deletions
@@ -1,16 +1,19 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """
 Aiter-based expert processing for Mori integration.
 """

-from typing import Any, Optional
+from typing import Optional

 import torch

 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-    rocm_aiter_fused_experts,
-)
+    rocm_aiter_fused_experts)
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+    TopKWeightAndReduceNoOP)


 class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
@@ -24,11 +27,9 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
     def __init__(
         self,
         max_num_tokens: int,
-        quant_config: FusedMoEQuantConfig = None,
+        quant_config: FusedMoEQuantConfig,
     ):
-        super().__init__(
-            quant_config=quant_config,
-        )
+        super().__init__(quant_config=quant_config, )
         self.max_num_tokens = max_num_tokens

     @property
@@ -51,10 +52,6 @@ def supports_expert_map(self) -> bool:

     def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
         """Aiter handles weight and reduce internally."""
-        from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
-            TopKWeightAndReduceNoOP,
-        )
-
         return TopKWeightAndReduceNoOP()

     def workspace_shapes(
@@ -101,6 +98,11 @@ def apply(
         Process expert computation using Aiter kernels.
         Works with pre-dispatched tokens from Mori all2all.
         """
+        if expert_tokens_meta is not None:
+            expert_num_tokens = expert_tokens_meta.expert_num_tokens
+        else:
+            expert_num_tokens = None
+
         # Call Aiter fused MoE expert processing
         result = rocm_aiter_fused_experts(
             hidden_states=hidden_states,
@@ -111,7 +113,7 @@
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input,
             expert_map=expert_map,
-            expert_num_tokens=expert_tokens_meta.expert_num_tokens if expert_tokens_meta is not None else None,
+            expert_num_tokens=expert_num_tokens,
             output_dtype=output.dtype,
             quant_config=self.quant_config,
             a1q_scale=a1q_scale,
vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 14 additions & 12 deletions
@@ -40,8 +40,8 @@
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
-from vllm.utils import (cdiv, direct_register_custom_op, has_deep_ep, has_pplx,
-                        has_mori, round_up)
+from vllm.utils import (cdiv, direct_register_custom_op, has_deep_ep, has_mori,
+                        has_pplx, round_up)
 from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
 from vllm.v1.worker.ubatching import dbo_current_ubatch_id

@@ -75,9 +75,12 @@ def _eplb_map_to_physical_and_record(

 if is_rocm_aiter_moe_enabled():
     from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
-        rocm_aiter_grouped_topk as grouped_topk)
+        rocm_aiter_grouped_topk)
+    grouped_topk_impl = rocm_aiter_grouped_topk
 else:
     from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
+    grouped_topk_impl = grouped_topk
+
 if current_platform.is_tpu():
     from .moe_pallas import fused_moe as fused_moe_pallas
 else:
@@ -210,21 +213,20 @@ def _maybe_make_prepare_finalize(
             use_fp8_dispatch=use_fp8_dispatch,
         )
     elif moe.use_mori_kernels:
-        use_fp8_dispatch = (
-            quant_config is not None
-            and quant_config.quant_dtype == current_platform.fp8_dtype()
-        )
+        use_fp8_dispatch = (quant_config is not None
+                            and quant_config.quant_dtype
+                            == current_platform.fp8_dtype())
         scale_dim = 0
         scale_type_size = 0
         quant_dtype = None
         if use_fp8_dispatch:
+            assert quant_config is not None
             scale_dim = quant_config.scale_shape(
                 moe.max_num_tokens,
                 moe.hidden_dim,
             )[-1]
-            scale_type_size = (
-                torch.float32.itemsize
-            )  # aiter quantization uses float32 scale
+            scale_type_size = (torch.float32.itemsize
+                               )  # aiter quantization uses float32 scale
             quant_dtype = quant_config.quant_dtype

         all_to_all_args = dict(
@@ -394,7 +396,7 @@ def select_gemm_impl(
                 quant_config=self.moe_quant_config,
             )
         elif (prepare_finalize.activation_format ==
-                FusedMoEActivationFormat.BatchedExperts):
+              FusedMoEActivationFormat.BatchedExperts):
             logger.debug("BatchedTritonExperts %s", self.moe)
             return BatchedTritonExperts(
                 max_num_tokens=self.moe.max_num_tokens,
@@ -1760,7 +1762,7 @@ def select_experts(
         if use_grouped_topk:
             assert topk_group is not None
             assert num_expert_group is not None
-            topk_weights, topk_ids = grouped_topk(
+            topk_weights, topk_ids = grouped_topk_impl(
                 hidden_states=hidden_states,
                 gating_output=router_logits,
                 topk=top_k,
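This file binds the selected routing function to a separate alias, grouped_topk_impl, instead of importing rocm_aiter_grouped_topk under the name grouped_topk, and call sites such as select_experts() go through that alias. A minimal sketch of the import-time selection pattern follows; the flag and both function bodies are placeholders, not the vLLM implementations.

USE_AITER = False  # stand-in for is_rocm_aiter_moe_enabled()


def grouped_topk(scores, k):
    # default implementation (placeholder body)
    return sorted(scores, reverse=True)[:k]


def rocm_aiter_grouped_topk(scores, k):
    # accelerated implementation (placeholder body)
    return sorted(scores, reverse=True)[:k]


# Bind the chosen implementation once at import time; neither original
# name is shadowed, and call sites use grouped_topk_impl.
grouped_topk_impl = rocm_aiter_grouped_topk if USE_AITER else grouped_topk

print(grouped_topk_impl([0.1, 0.7, 0.3], 2))  # [0.7, 0.3]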
