1 change: 0 additions & 1 deletion tests/ut/ops/test_token_dispatcher.py
@@ -70,7 +70,6 @@ def test_init(self):
self.assertFalse(self.dispatcher.with_quant)
self.assertTrue(self.dispatcher.enable_dispatch_v2)
self.assertTrue(self.dispatcher.need_extra_args)
self.assertTrue(self.dispatcher.a3_need_extra_args)

def test_get_dispatch_mc2_kwargs_without_quant(self):
hidden_states = torch.randn(10, 128)
6 changes: 5 additions & 1 deletion vllm_ascend/ops/fused_moe/fused_moe.py
@@ -116,7 +116,11 @@ def apply(self,
# to avoid accumulating too many tokens on a single rank.
# currently it is only activated when doing profile runs.
if enable_force_load_balance:
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
random_matrix = torch.rand(topk_ids.size(0),
global_num_experts,
device=topk_ids.device)
topk_ids = torch.argsort(
random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)

moe_comm_method = get_forward_context().moe_comm_method
return moe_comm_method.fused_experts(
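
For readers skimming the diff, here is a small self-contained sketch (toy shapes and values assumed, not taken from the repository) of why the rand + argsort pattern replaces randint_like: sorting a row of uniform random scores yields a random permutation of expert ids, so the first top-k entries per token are guaranteed distinct, whereas independent randint draws can route a token to the same expert twice.

```python
import torch

# Assumed toy sizes for illustration only.
num_tokens, top_k, global_num_experts = 4, 2, 8
topk_ids = torch.zeros(num_tokens, top_k, dtype=torch.int32)

# Previous behaviour: independent uniform draws; duplicates within a row are possible.
old_ids = torch.randint_like(topk_ids, 0, global_num_experts)

# New behaviour: argsort of a random score matrix gives a random permutation per row,
# so slicing the first top_k columns yields distinct expert ids for every token.
random_matrix = torch.rand(num_tokens, global_num_experts)
new_ids = torch.argsort(random_matrix, dim=1)[:, :top_k].to(topk_ids.dtype)

assert all(row.unique().numel() == top_k for row in new_ids)
```
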
36 changes: 22 additions & 14 deletions vllm_ascend/ops/fused_moe/token_dispatcher.py
@@ -25,6 +25,7 @@

import torch
import torch_npu
from vllm.config import get_current_vllm_config
from vllm.distributed.parallel_state import get_ep_group

from vllm_ascend.distributed.parallel_state import get_mc2_group
@@ -100,15 +101,31 @@ def __init__(self, **kwargs):
self.need_extra_args = (
get_ascend_device_type() == AscendDeviceType._910_93)

# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
self.a3_need_extra_args = \
get_ascend_device_type() == AscendDeviceType._910_93
# NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and
# HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
# improve communication performance.
self.need_expert_scale = is_hierarchical_communication_enabled()
self.with_quant = False

# Here we need to calculate global_bs = max_bs_per_rank * ep_world_size so that the
# dispatch & combine operators can run with a different input num_tokens on each rank.
vllm_config = get_current_vllm_config()
scheduler_config = vllm_config.scheduler_config
compilation_config = vllm_config.compilation_config
speculative_config = vllm_config.speculative_config
tp_size = vllm_config.parallel_config.tensor_parallel_size
uniform_decode_query_len = 1 if not speculative_config else \
1 + speculative_config.num_speculative_tokens
decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs',
0)
max_num_reqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs)
if compilation_config.cudagraph_capture_sizes:
max_num_tokens = compilation_config.max_cudagraph_capture_size
else:
max_num_tokens = min(max_num_reqs * uniform_decode_query_len, 512)
num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
self.global_bs = num_tokens_per_tp_rank * self.ep_world_size

def get_dispatch_mc2_kwargs(
self,
hidden_states: torch.Tensor,
@@ -130,7 +147,7 @@ def get_dispatch_mc2_kwargs(
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
"moe_expert_num": moe_expert_num,
"global_bs": 0,
"global_bs": self.global_bs,
"expert_token_nums_type": 0,
}

@@ -147,10 +164,6 @@ def get_dispatch_mc2_kwargs(
"tp_world_size": 1,
"tp_rank_id": 0,
})
if self.a3_need_extra_args and self.enable_dispatch_v2:
stage1_kwargs.update({
"x_active_mask": mc2_mask,
})
if self.need_expert_scale:
stage1_kwargs.update({
"expert_scales":
@@ -214,7 +227,6 @@ def token_dispatch(self,
context_metadata = {
"topk_ids": topk_ids,
"topk_weights": topk_weights,
"mc2_mask": mc2_mask,
"expert_map": expert_map,
"ep_recv_counts": ep_recv_counts,
"tp_recv_counts": tp_recv_counts,
@@ -243,7 +255,6 @@ def get_combine_mc_kwargs(self, hidden_states: torch.Tensor,
ep_recv_counts = context_metadata["ep_recv_counts"]
tp_recv_counts = context_metadata["tp_recv_counts"]
assist_info_for_combine = context_metadata["assist_info_for_combine"]
mc2_mask = context_metadata["mc2_mask"]
expand_scales = context_metadata["expand_scales"]

assert expert_map is not None
@@ -256,7 +267,7 @@ def get_combine_mc_kwargs(self, hidden_states: torch.Tensor,
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
"moe_expert_num": moe_expert_num,
"global_bs": 0,
"global_bs": self.global_bs,
}

if self.with_quant:
@@ -285,9 +296,6 @@ def get_combine_mc_kwargs(self, hidden_states: torch.Tensor,
"tp_rank_id": 0,
})

if self.a3_need_extra_args and self.enable_dispatch_v2:
stage3_kwargs["x_active_mask"] = mc2_mask

kwargs_mc2.update(stage3_kwargs)
return kwargs_mc2

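
A hedged, standalone sketch of the global_bs computation the PR adds to the dispatcher's __init__: the helper below mirrors the logic in the diff but is illustrative only, and the parameter names are assumed stand-ins for the values read from get_current_vllm_config().

```python
from typing import Optional


def compute_global_bs(max_num_seqs: int,
                      decode_max_num_seqs: int,
                      num_speculative_tokens: Optional[int],
                      tp_size: int,
                      ep_world_size: int,
                      max_cudagraph_capture_size: Optional[int] = None) -> int:
    # Each decode step produces 1 token per request, plus any speculative tokens.
    uniform_decode_query_len = 1 + (num_speculative_tokens or 0)
    max_num_reqs = max(max_num_seqs, decode_max_num_seqs)
    if max_cudagraph_capture_size is not None:
        max_num_tokens = max_cudagraph_capture_size
    else:
        max_num_tokens = min(max_num_reqs * uniform_decode_query_len, 512)
    # Ceiling division: the largest token count a single TP rank can contribute.
    num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
    # global_bs = max_bs_per_rank * ep_world_size, so dispatch & combine agree on
    # a fixed upper bound even when each EP rank feeds a different num_tokens.
    return num_tokens_per_tp_rank * ep_world_size


# Example: 64 requests, 1 speculative token, TP=4, EP=16 -> (128 + 3) // 4 * 16 == 512
assert compute_global_bs(64, 0, 1, 4, 16) == 512
```
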
8 changes: 6 additions & 2 deletions vllm_ascend/quantization/w4a8_dynamic.py
@@ -371,8 +371,12 @@ def apply(
# to avoid accumulating too many tokens on a single rank.
# currently it is only activated when doing profile runs.
if enable_force_load_balance:
topk_ids = torch.randint_like(
topk_ids, 0, global_num_experts - global_redundant_expert_num)
random_matrix = torch.rand(topk_ids.size(0),
global_num_experts -
global_redundant_expert_num,
device=topk_ids.device)
topk_ids = torch.argsort(
random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)

topk_weights = topk_weights.to(x.dtype)

8 changes: 6 additions & 2 deletions vllm_ascend/quantization/w8a8_dynamic.py
@@ -216,8 +216,12 @@ def apply(
# to avoid accumulating too many tokens on a single rank.
# currently it is only activated when doing profile runs.
if enable_force_load_balance:
topk_ids = torch.randint_like(
topk_ids, 0, global_num_experts - global_redundant_expert_num)
random_matrix = torch.rand(topk_ids.size(0),
global_num_experts -
global_redundant_expert_num,
device=topk_ids.device)
topk_ids = torch.argsort(
random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)

topk_weights = topk_weights.to(self.in_dtype)

5 changes: 1 addition & 4 deletions vllm_ascend/spec_decode/mtp_proposer.py
@@ -807,10 +807,7 @@ def _propose(

num_indices = last_token_indices.shape[0]
if lmhead_tp_enable():
if not self.runner.with_prefill:
max_num_reqs_across_dp = num_input_tokens
else:
max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs
max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs * self.runner.uniform_decode_query_len
last_token_indices = nn.functional.pad(
last_token_indices,
(0, max_num_reqs_across_dp - num_indices))
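
A minimal sketch (toy sizes assumed) of what the simplified branch now does when lm-head tensor parallelism is enabled: every rank pads its last-token indices to the same fixed length, max_num_seqs * uniform_decode_query_len, so the gather across ranks sees uniform shapes regardless of whether the step is prefill or decode.

```python
import torch
import torch.nn as nn

# Assumed toy values.
max_num_seqs, uniform_decode_query_len = 8, 2
last_token_indices = torch.tensor([3, 7, 12])  # one sampling position per request

max_num_reqs_across_dp = max_num_seqs * uniform_decode_query_len  # 16
padded = nn.functional.pad(
    last_token_indices,
    (0, max_num_reqs_across_dp - last_token_indices.shape[0]))

assert padded.shape[0] == max_num_reqs_across_dp  # padding entries are zeros
```
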
31 changes: 29 additions & 2 deletions vllm_ascend/worker/model_runner_v1.py
@@ -954,6 +954,25 @@ def _init_mrope_positions(self, req_state: CachedRequestState):
req_state.mm_features,
)

def _skip_all_reduce_across_dp_group(self) -> bool:
# NOTE: We can skip the all_reduce operation and avoid padding tokens
# to max_tokens_across_dp on D nodes. For MoE models, we must ensure that
# num_tokens DOES NOT exceed mc2_tokens_capacity, which guarantees that the
# moe_comm_method of each rank is MC2. For dense models, skipping all_reduce
# is unnecessary: dp_size in dense-model deployments is always small, so the
# collective communication is cheap and can be overlapped by async scheduling.
if not is_moe_model(self.vllm_config):
return False
if self.compilation_config.cudagraph_capture_sizes:
potential_max_num_tokens = self.compilation_config.max_cudagraph_capture_size
else:
potential_max_num_tokens = self.max_num_reqs * self.uniform_decode_query_len
# Skipping the all_reduce across the dp group is only valid if the
# moe_comm_method of each rank is MC2 and recomputation can never happen on
# D nodes, so we also require recompute_scheduler_enable to be True.
return self.is_kv_consumer and not self.in_profile_run and self.ascend_config.recompute_scheduler_enable and self._select_moe_comm_method(
potential_max_num_tokens) == MoECommType.MC2

def _sync_metadata_across_dp(
self, num_tokens: int,
with_prefill: bool) -> tuple[int, Optional[torch.Tensor], bool]:
@@ -965,6 +984,14 @@ def _sync_metadata_across_dp(
# immediately once the other two flags are no longer needed.
if self.dp_size == 1:
return num_tokens, None, with_prefill

if self._skip_all_reduce_across_dp_group():
num_tokens_after_padding = torch.tensor([num_tokens] *
self.dp_size,
device="cpu",
dtype=torch.int32)
return num_tokens, num_tokens_after_padding, with_prefill

# Sync num_tokens, with_prefill across dp ranks
num_tokens_tensor = torch.tensor([
num_tokens if i == self.dp_rank else 0 for i in range(self.dp_size)
@@ -1959,7 +1986,7 @@ def _prepare_inputs(
attn_metadata[layer_name] = attn_metadata_i

if lmhead_tp_enable():
max_num_reqs_across_dp = maybe_padded_num_tokens if not with_prefill else self.max_num_reqs
max_num_reqs_across_dp = self.max_num_reqs * self.uniform_decode_query_len
logits_indices = nn.functional.pad(
logits_indices,
(0, max_num_reqs_across_dp - logits_indices.shape[0]))
@@ -3102,7 +3129,7 @@ def _dummy_run(

need_dummy_logits = (not self.in_profile_run
and lmhead_tp_enable())
max_num_reqs_across_dp = num_tokens_padded if not with_prefill else max_num_reqs
max_num_reqs_across_dp = max_num_reqs * self.uniform_decode_query_len
dummy_indices = torch.zeros(max_num_reqs_across_dp,
dtype=torch.int32)

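
Finally, a hedged sketch of the new decode-node fast path: the guard below condenses the conditions from _skip_all_reduce_across_dp_group together with the early return in _sync_metadata_across_dp. The mc2_tokens_capacity comparison stands in for self._select_moe_comm_method(...) == MoECommType.MC2, per the comment in the diff; all parameter names are assumptions, not the runner's real attributes.

```python
import torch


def maybe_skip_dp_sync(num_tokens: int,
                       dp_size: int,
                       is_moe: bool,
                       is_kv_consumer: bool,
                       in_profile_run: bool,
                       recompute_scheduler_enable: bool,
                       potential_max_num_tokens: int,
                       mc2_tokens_capacity: int):
    """Return (num_tokens, num_tokens_after_padding, with_prefill) without any
    collective call when the skip conditions hold, else None so the caller
    falls back to the all_reduce path."""
    fits_mc2 = potential_max_num_tokens <= mc2_tokens_capacity
    if not (is_moe and is_kv_consumer and not in_profile_run
            and recompute_scheduler_enable and fits_mc2):
        return None
    # No padding to the max across DP ranks: every entry is the local count.
    num_tokens_after_padding = torch.tensor([num_tokens] * dp_size,
                                            device="cpu",
                                            dtype=torch.int32)
    return num_tokens, num_tokens_after_padding, False  # decode-only, no prefill
```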