 from vllm.distributed import get_dp_group, get_ep_group
 from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
-from vllm.utils import has_deep_ep, has_pplx, has_mori
+from vllm.utils import has_deep_ep, has_mori, has_pplx
 from vllm.utils.flashinfer import has_flashinfer_all2all

 from .base_device_communicator import All2AllManagerBase, Cache
@@ -439,24 +439,26 @@ def cleanup(self):
         self.mapping = None
         self.initialized = False

+
 class MoriAll2AllManager(All2AllManagerBase):
     """
     All2All communication based on mori kernels.
     """
+
     def __init__(self, cpu_group):
         assert has_mori(
-        ), "mori not found. Please follow https://github.com/ROCm/mori/blob/main/README.md#installation to install mori."  # noqa
+        ), "Please install mori from https://github.com/ROCm/mori."

         super().__init__(cpu_group)
         self.handle_cache = Cache()
         self.config = None
         self._op_handles = {}  # Cache for EpDispatchCombineOp instances
         self._shmem_initialized = False
         # Delay mori shmem initialization until first use
-        logger.debug(f"[rank {self.rank}] MoriAll2AllManager created, shmem will be initialized lazily")
+        logger.debug("[rank %s] MoriAll2AllManager created", self.rank)

     def _ensure_shmem_initialized(self):
-        """Ensure mori's shared memory system is initialized (lazy initialization)"""
+        """Initialize mori's shared memory system lazily"""
         if self._shmem_initialized:
             return

@@ -473,45 +475,60 @@ def _ensure_shmem_initialized(self):
             if backend is None:
                 raise RuntimeError("No valid distributed backend found")

-            logger.debug(f"[rank {self.rank}] PyTorch distributed ready with backend: {backend}")
+            logger.debug(
+                "[rank %s] PyTorch distributed ready with backend: %s",
+                self.rank, backend)

-            current_group = self.cpu_group if self.cpu_group is not None else dist.group.WORLD
+            current_group = (self.cpu_group if self.cpu_group is not None else
+                             dist.group.WORLD)

             # TODO(inhyeok): make group_name more reasonable
             group_name = "default"
             try:
+                import contextlib
+
                 import torch._C._distributed_c10d as c10d

                 # Try to unregister first in case it exists
-                try:
+                with contextlib.suppress(RuntimeError):
                     c10d._unregister_process_group(group_name)
-                except:
-                    pass

                 # Register the current process group
                 c10d._register_process_group(group_name, current_group)
-                logger.debug(f"[rank {self.rank}] Registered process group '{group_name}'")
+                logger.debug("[rank %s] Registered process group '%s'",
+                             self.rank, group_name)

                 # Initialize mori shmem with the registered group
                 mori.shmem.shmem_torch_process_group_init(group_name)
-                logger.debug(f"[rank {self.rank}] Torch process group shmem initialization successful")
+                logger.debug(
+                    "[rank %s] torch process group shmem init success",
+                    self.rank)
                 self._shmem_initialized = True
                 return

             except Exception as torch_error:
-                logger.debug(f"[rank {self.rank}] Torch process group shmem init failed: {torch_error}")
+                logger.debug(
+                    "[rank %s] torch process group shmem init failed: %s",
+                    self.rank, torch_error)

             self._shmem_initialized = True

         except Exception as e:
-            logger.error(f"[rank {self.rank}] mori shmem initialization failed: {e}")
+            logger.error("[rank %s] mori shmem initialization failed: %s",
+                         self.rank, e)
             # Don't fail completely - mark as initialized to avoid retry loops
             self._shmem_initialized = True
-            logger.warning(f"[rank {self.rank}] Continuing without mori shmem optimization")
-
-    def _make_mori_config(self, max_num_tokens: int, num_local_experts: int,
-                          experts_per_token: int, hidden_dim: int,
-                          scale_dim: int, scale_type_size: int,
+            logger.warning(
+                "[rank %s] Continuing without mori shmem optimization",
+                self.rank)
+
+    def _make_mori_config(self,
+                          max_num_tokens: int,
+                          num_local_experts: int,
+                          experts_per_token: int,
+                          hidden_dim: int,
+                          scale_dim: int,
+                          scale_type_size: int,
                           data_type: torch.dtype = torch.bfloat16,
                           quant_dtype: torch.dtype = None):
         """Create mori EpDispatchCombineConfig"""
@@ -546,9 +563,8 @@ def _make_mori_config(self, max_num_tokens: int, num_local_experts: int,

             # Determine kernel type based on topology
             kernel_type=(EpDispatchCombineKernelType.InterNode
-                         if self.internode
-                         else EpDispatchCombineKernelType.IntraNode)
-                         )
+                         if self.internode else
+                         EpDispatchCombineKernelType.IntraNode))

         return config

@@ -578,13 +594,16 @@ def get_handle(self, kwargs):
         scale_type_size = kwargs.get('scale_type_size')

         # Validate required parameters
-        if any(param is None for param in [max_num_tokens, num_local_experts,
-                                           experts_per_token, hidden_dim]):
-            raise ValueError("Missing required parameters for mori handle creation")
+        if any(
+                param is None for param in
+            [max_num_tokens, num_local_experts, experts_per_token, hidden_dim
+             ]):
+            raise ValueError(
+                "Missing required parameters for mori handle creation")

         # Create cache key
         cache_key = (max_num_tokens, num_local_experts, experts_per_token,
-                    hidden_dim, data_type)
+                     hidden_dim, data_type)

         # Check cache first
         if cache_key in self._op_handles:
@@ -607,17 +626,22 @@ def get_handle(self, kwargs):
         # Cache the handle
         self._op_handles[cache_key] = op

-        logger.debug(f"[rank {self.dp_rank}] Created mori handle with config: "
-                     f"tokens={max_num_tokens}, experts={num_local_experts}, "
-                     f"topk={experts_per_token}, hidden={hidden_dim}")
+        logger.debug(
+            "[rank %s] Created mori handle with config: tokens=%d, experts=%d,"
+            " topk=%d, hidden_dim=%d", self.dp_rank, max_num_tokens,
+            num_local_experts, experts_per_token, hidden_dim)

         return op

-    def dispatch(self, hidden_states: torch.Tensor,
-                 router_logits: torch.Tensor):
+    def dispatch(self,
+                 hidden_states: torch.Tensor,
+                 router_logits: torch.Tensor,
+                 is_sequence_parallel: bool = False):
         raise NotImplementedError

-    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def combine(self,
+                hidden_states: torch.Tensor,
+                is_sequence_parallel: bool = False):
         raise NotImplementedError

     def destroy(self):
@@ -626,17 +650,23 @@ def destroy(self):
             # Clear operation handle cache
             self._op_handles.clear()

-            # Try to finalize mori shared memory if it was successfully initialized
+            # finalize mori shared memory if it was initialized
             if self._shmem_initialized:
                 try:
                     import mori.shmem
+
                     # Check if shmem is actually active before finalizing
                     mori.shmem.shmem_finalize()
-                    logger.debug(f"[rank {self.dp_rank}] mori shmem finalized")
+                    logger.debug("[rank %s] mori shmem finalized",
+                                 self.dp_rank)
                 except Exception as shmem_error:
-                    logger.debug(f"[rank {self.dp_rank}] shmem finalize failed (may not have been active): {shmem_error}")
+                    logger.debug(
+                        "[rank %s] shmem finalize failed "
+                        "(may not have been active): %s", self.dp_rank,
+                        shmem_error)

-            logger.debug(f"[rank {self.dp_rank}] mori resources cleaned up")
+            logger.debug("[rank %s] mori resources cleaned up", self.dp_rank)

         except Exception as e:
-            logger.warning(f"[rank {self.dp_rank}] Error during mori cleanup: {e}")
+            logger.warning("[rank %s] Error during mori cleanup: %s",
+                           self.dp_rank, e)
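
Note on the `get_handle` caching above: handles are keyed by the dispatch/combine shape, so repeated calls with the same configuration reuse a single `EpDispatchCombineOp`. A minimal, hypothetical usage sketch follows; it assumes `manager` is an already-constructed `MoriAll2AllManager`, and the concrete values are illustrative only (just the kwarg names come from the diff):

```python
import torch

# Illustrative values; only the kwarg names are taken from the diff above.
kwargs = {
    "max_num_tokens": 1024,      # per-rank token budget (assumed value)
    "num_local_experts": 8,      # experts hosted on this rank
    "experts_per_token": 2,      # router top-k
    "hidden_dim": 4096,          # model hidden size
    "scale_dim": 0,              # optional quant-scale metadata (kwargs.get)
    "scale_type_size": 0,
    "data_type": torch.bfloat16,
}

# The first call builds and caches an EpDispatchCombineOp for this shape;
# later calls with the same (max_num_tokens, num_local_experts,
# experts_per_token, hidden_dim, data_type) key return the cached op.
op = manager.get_handle(kwargs)          # manager: MoriAll2AllManager (assumed)
assert manager.get_handle(kwargs) is op  # hits self._op_handles cache
```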