diff --git a/tests/ut/ops/test_token_dispatcher.py b/tests/ut/ops/test_token_dispatcher.py
index 15fd094f76d..26a741c39c5 100644
--- a/tests/ut/ops/test_token_dispatcher.py
+++ b/tests/ut/ops/test_token_dispatcher.py
@@ -70,7 +70,6 @@ def test_init(self):
         self.assertFalse(self.dispatcher.with_quant)
         self.assertTrue(self.dispatcher.enable_dispatch_v2)
         self.assertTrue(self.dispatcher.need_extra_args)
-        self.assertTrue(self.dispatcher.a3_need_extra_args)
 
     def test_get_dispatch_mc2_kwargs_without_quant(self):
         hidden_states = torch.randn(10, 128)
diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py
index 80c7fcc2c8c..aa039173235 100644
--- a/vllm_ascend/ops/fused_moe/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe/fused_moe.py
@@ -116,7 +116,11 @@ def apply(self,
         # to avoid accumulating too much tokens on a single rank.
         # currently it is only activated when doing profile runs.
         if enable_force_load_balance:
-            topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
+            random_matrix = torch.rand(topk_ids.size(0),
+                                       global_num_experts,
+                                       device=topk_ids.device)
+            topk_ids = torch.argsort(
+                random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)
 
         moe_comm_method = get_forward_context().moe_comm_method
         return moe_comm_method.fused_experts(
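For reference, a minimal standalone sketch of the routing trick introduced above (all sizes are illustrative): argsort over a per-token random matrix yields a random permutation of expert ids, so the first top_k columns are distinct experts, whereas torch.randint_like can assign the same expert to a token more than once.

    import torch

    num_tokens, top_k, num_experts = 4, 2, 8
    topk_ids = torch.zeros(num_tokens, top_k, dtype=torch.int32)  # stand-in for router output

    # Random permutation per row -> the first top_k columns are distinct expert ids.
    random_matrix = torch.rand(topk_ids.size(0), num_experts, device=topk_ids.device)
    balanced_ids = torch.argsort(random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)

    assert balanced_ids.shape == topk_ids.shape
    assert all(len(set(row.tolist())) == top_k for row in balanced_ids)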
diff --git a/vllm_ascend/ops/fused_moe/token_dispatcher.py b/vllm_ascend/ops/fused_moe/token_dispatcher.py
index 75035a47a62..e45504d9b8d 100644
--- a/vllm_ascend/ops/fused_moe/token_dispatcher.py
+++ b/vllm_ascend/ops/fused_moe/token_dispatcher.py
@@ -25,6 +25,7 @@
 import torch
 import torch_npu
+from vllm.config import get_current_vllm_config
 from vllm.distributed.parallel_state import get_ep_group
 
 from vllm_ascend.distributed.parallel_state import get_mc2_group
@@ -100,15 +101,31 @@ def __init__(self, **kwargs):
         self.need_extra_args = (
             get_ascend_device_type() == AscendDeviceType._910_93)
-        # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
-        self.a3_need_extra_args = \
-            get_ascend_device_type() == AscendDeviceType._910_93
         # NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and
         # HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
         # improve communication performance.
         self.need_expert_scale = is_hierarchical_communication_enabled()
         self.with_quant = False
+        # Here we need to calculate global_bs = max_bs_per_rank * ep_world_size so that the
+        # dispatch & combine operators can run with a different input num_tokens on each rank.
+        vllm_config = get_current_vllm_config()
+        scheduler_config = vllm_config.scheduler_config
+        compilation_config = vllm_config.compilation_config
+        speculative_config = vllm_config.speculative_config
+        tp_size = vllm_config.parallel_config.tensor_parallel_size
+        uniform_decode_query_len = 1 if not speculative_config else \
+            1 + speculative_config.num_speculative_tokens
+        decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs',
+                                      0)
+        max_num_reqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs)
+        if compilation_config.cudagraph_capture_sizes:
+            max_num_tokens = compilation_config.max_cudagraph_capture_size
+        else:
+            max_num_tokens = min(max_num_reqs * uniform_decode_query_len, 512)
+        num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
+        self.global_bs = num_tokens_per_tp_rank * self.ep_world_size
+
     def get_dispatch_mc2_kwargs(
         self,
         hidden_states: torch.Tensor,
@@ -130,7 +147,7 @@ def get_dispatch_mc2_kwargs(
             "expert_shard_type": 0,
             "shared_expert_rank_num": 0,
             "moe_expert_num": moe_expert_num,
-            "global_bs": 0,
+            "global_bs": self.global_bs,
             "expert_token_nums_type": 0,
         }
 
@@ -147,10 +164,6 @@ def get_dispatch_mc2_kwargs(
                 "tp_world_size": 1,
                 "tp_rank_id": 0,
             })
-        if self.a3_need_extra_args and self.enable_dispatch_v2:
-            stage1_kwargs.update({
-                "x_active_mask": mc2_mask,
-            })
         if self.need_expert_scale:
             stage1_kwargs.update({
                 "expert_scales":
@@ -214,7 +227,6 @@ def token_dispatch(self,
         context_metadata = {
             "topk_ids": topk_ids,
             "topk_weights": topk_weights,
-            "mc2_mask": mc2_mask,
             "expert_map": expert_map,
             "ep_recv_counts": ep_recv_counts,
             "tp_recv_counts": tp_recv_counts,
@@ -243,7 +255,6 @@ def get_combine_mc_kwargs(self, hidden_states: torch.Tensor,
         ep_recv_counts = context_metadata["ep_recv_counts"]
         tp_recv_counts = context_metadata["tp_recv_counts"]
         assist_info_for_combine = context_metadata["assist_info_for_combine"]
-        mc2_mask = context_metadata["mc2_mask"]
         expand_scales = context_metadata["expand_scales"]
 
         assert expert_map is not None
@@ -256,7 +267,7 @@ def get_combine_mc_kwargs(self, hidden_states: torch.Tensor,
             "expert_shard_type": 0,
             "shared_expert_rank_num": 0,
             "moe_expert_num": moe_expert_num,
-            "global_bs": 0,
+            "global_bs": self.global_bs,
         }
 
         if self.with_quant:
@@ -285,9 +296,6 @@ def get_combine_mc_kwargs(self, hidden_states: torch.Tensor,
                 "tp_rank_id": 0,
             })
 
-        if self.a3_need_extra_args and self.enable_dispatch_v2:
-            stage3_kwargs["x_active_mask"] = mc2_mask
-
         kwargs_mc2.update(stage3_kwargs)
         return kwargs_mc2
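A worked example of the global_bs computation added above, under assumed values (tp_size=2, ep_world_size=16, max_num_seqs=32, 2 speculative tokens); the numbers are illustrative, not taken from the change:

    # global_bs = max tokens a single TP rank may contribute * EP world size
    tp_size = 2
    ep_world_size = 16
    max_num_seqs = 32
    num_speculative_tokens = 2

    uniform_decode_query_len = 1 + num_speculative_tokens                # 3 tokens per request per step
    max_num_tokens = min(max_num_seqs * uniform_decode_query_len, 512)   # 96
    num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size   # ceil(96 / 2) = 48
    global_bs = num_tokens_per_tp_rank * ep_world_size                   # 48 * 16 = 768
    print(global_bs)  # 768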
diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/w4a8_dynamic.py
index a73050c3123..dc5f580d47b 100644
--- a/vllm_ascend/quantization/w4a8_dynamic.py
+++ b/vllm_ascend/quantization/w4a8_dynamic.py
@@ -371,8 +371,12 @@ def apply(
         # to avoid accumulating too much tokens on a single rank.
         # currently it is only activated when doing profile runs.
         if enable_force_load_balance:
-            topk_ids = torch.randint_like(
-                topk_ids, 0, global_num_experts - global_redundant_expert_num)
+            random_matrix = torch.rand(topk_ids.size(0),
+                                       global_num_experts -
+                                       global_redundant_expert_num,
+                                       device=topk_ids.device)
+            topk_ids = torch.argsort(
+                random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)
 
         topk_weights = topk_weights.to(x.dtype)
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index 2e86dd6e74c..bde7aed2aa4 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -216,8 +216,12 @@ def apply(
         # to avoid accumulating too much tokens on a single rank.
         # currently it is only activated when doing profile runs.
         if enable_force_load_balance:
-            topk_ids = torch.randint_like(
-                topk_ids, 0, global_num_experts - global_redundant_expert_num)
+            random_matrix = torch.rand(topk_ids.size(0),
+                                       global_num_experts -
+                                       global_redundant_expert_num,
+                                       device=topk_ids.device)
+            topk_ids = torch.argsort(
+                random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)
 
         topk_weights = topk_weights.to(self.in_dtype)
diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py
index 08930190974..9dce620ea4f 100644
--- a/vllm_ascend/spec_decode/mtp_proposer.py
+++ b/vllm_ascend/spec_decode/mtp_proposer.py
@@ -807,10 +807,7 @@ def _propose(
         num_indices = last_token_indices.shape[0]
 
         if lmhead_tp_enable():
-            if not self.runner.with_prefill:
-                max_num_reqs_across_dp = num_input_tokens
-            else:
-                max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs
+            max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs * self.runner.uniform_decode_query_len
             last_token_indices = nn.functional.pad(
                 last_token_indices, (0, max_num_reqs_across_dp - num_indices))
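A small sketch of the padding target now shared by the MTP proposer and the model runner for lmhead tensor parallelism; the sizes below are assumptions for illustration, not taken from the change. Every rank pads its token indices to the same static length, max_num_seqs * uniform_decode_query_len, so the lm_head TP gather sees uniform shapes regardless of how many indices each rank actually holds.

    import torch
    import torch.nn as nn

    max_num_seqs = 8
    uniform_decode_query_len = 3                                       # 1 + num_speculative_tokens
    max_num_reqs_across_dp = max_num_seqs * uniform_decode_query_len   # 24

    last_token_indices = torch.arange(5)                               # this rank only has 5 indices
    padded = nn.functional.pad(
        last_token_indices, (0, max_num_reqs_across_dp - last_token_indices.shape[0]))
    assert padded.shape[0] == max_num_reqs_across_dp                   # zero-padded up to 24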
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index f6fded7d872..862612d1377 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -954,6 +954,25 @@ def _init_mrope_positions(self, req_state: CachedRequestState):
             req_state.mm_features,
         )
 
+    def _skip_all_reduce_across_dp_group(self) -> bool:
+        # NOTE: We can skip the all_reduce operation and avoid padding tokens
+        # to max_tokens_across_dp on D nodes. In MoE models, we must ensure that
+        # num_tokens DOES NOT exceed mc2_tokens_capacity, which means the
+        # moe_comm_method of each rank is MC2. For dense models, skipping the
+        # all_reduce is unnecessary: dp_size in dense deployments is always small,
+        # so the collective communication is cheap and can be overlapped by async scheduling.
+        if not is_moe_model(self.vllm_config):
+            return False
+        if self.compilation_config.cudagraph_capture_sizes:
+            potential_max_num_tokens = self.compilation_config.max_cudagraph_capture_size
+        else:
+            potential_max_num_tokens = self.max_num_reqs * self.uniform_decode_query_len
+        # To ensure that skipping the all_reduce across the dp group is valid, we need to
+        # ensure that the moe_comm_method of each rank is MC2 and that recomputation never
+        # happens on D nodes, so here we also check whether recompute_scheduler_enable is True.
+        return self.is_kv_consumer and not self.in_profile_run and self.ascend_config.recompute_scheduler_enable and self._select_moe_comm_method(
+            potential_max_num_tokens) == MoECommType.MC2
+
     def _sync_metadata_across_dp(
             self, num_tokens: int,
             with_prefill: bool) -> tuple[int, Optional[torch.Tensor], bool]:
@@ -965,6 +984,14 @@ def _sync_metadata_across_dp(
         # immediately once the other two flags are no longer needed.
         if self.dp_size == 1:
             return num_tokens, None, with_prefill
+
+        if self._skip_all_reduce_across_dp_group():
+            num_tokens_after_padding = torch.tensor([num_tokens] *
+                                                    self.dp_size,
+                                                    device="cpu",
+                                                    dtype=torch.int32)
+            return num_tokens, num_tokens_after_padding, with_prefill
+
         # Sync num_tokens, with_prefill across dp ranks
         num_tokens_tensor = torch.tensor([
             num_tokens if i == self.dp_rank else 0 for i in range(self.dp_size)
@@ -1959,7 +1986,7 @@ def _prepare_inputs(
                 attn_metadata[layer_name] = attn_metadata_i
 
         if lmhead_tp_enable():
-            max_num_reqs_across_dp = maybe_padded_num_tokens if not with_prefill else self.max_num_reqs
+            max_num_reqs_across_dp = self.max_num_reqs * self.uniform_decode_query_len
             logits_indices = nn.functional.pad(
                 logits_indices,
                 (0, max_num_reqs_across_dp - logits_indices.shape[0]))
@@ -3102,7 +3129,7 @@ def _dummy_run(
 
             need_dummy_logits = (not self.in_profile_run
                                  and lmhead_tp_enable())
-            max_num_reqs_across_dp = num_tokens_padded if not with_prefill else max_num_reqs
+            max_num_reqs_across_dp = max_num_reqs * self.uniform_decode_query_len
             dummy_indices = torch.zeros(max_num_reqs_across_dp,
                                         dtype=torch.int32)
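A simplified sketch of the DP-sync fast path enabled by the check above; the helper names are illustrative and not part of the change. Normally every rank is padded toward the maximum token count across the DP group, which requires an all_reduce; when every rank is guaranteed to stay on the MC2 path, a rank can instead fill the per-rank table with its own local count and skip both the collective and the padding.

    import torch

    def sync_num_tokens_normal(all_rank_tokens: list) -> torch.Tensor:
        # What the all_reduce path effectively produces: every rank padded to the max.
        padded = max(all_rank_tokens)
        return torch.tensor([padded] * len(all_rank_tokens), dtype=torch.int32)

    def sync_num_tokens_fast(local_num_tokens: int, dp_size: int) -> torch.Tensor:
        # Fast path: purely local, no collective communication, no padding to the max.
        return torch.tensor([local_num_tokens] * dp_size, dtype=torch.int32)

    print(sync_num_tokens_normal([7, 12, 3, 9]))  # tensor([12, 12, 12, 12], dtype=torch.int32)
    print(sync_num_tokens_fast(7, dp_size=4))     # tensor([7, 7, 7, 7], dtype=torch.int32)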