6 changes: 5 additions & 1 deletion vllm_ascend/ops/fused_moe/fused_moe.py
@@ -116,7 +116,11 @@ def apply(self,
# to avoid accumulating too many tokens on a single rank.
# currently it is only activated when doing profile runs.
if enable_force_load_balance:
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
random_matrix = torch.rand(topk_ids.size(0),
global_num_experts,
device=topk_ids.device)
topk_ids = torch.argsort(
random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)

moe_comm_method = get_forward_context().moe_comm_method
return moe_comm_method.fused_experts(
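Aside: the replacement of `torch.randint_like` with `torch.rand` + `torch.argsort` (applied identically in `w4a8_dynamic.py` and `w8a8_dynamic.py` below) turns the forced load-balancing ids from sampling with replacement into sampling without replacement. Argsorting a per-token random row and keeping the first `top_k` columns yields distinct expert ids for every token, whereas `randint_like` could assign the same expert twice to one token. A minimal standalone sketch of the idea, with illustrative shapes:

```python
import torch

num_tokens, top_k, num_experts = 4, 2, 8

# Old approach: independent uniform draws, so a row may contain the same expert twice.
dup_possible = torch.randint(0, num_experts, (num_tokens, top_k))

# New approach: argsort a per-token random row and keep the first top_k columns.
# Each row is a random permutation of the experts truncated to top_k, so the ids
# within a row are always distinct.
random_matrix = torch.rand(num_tokens, num_experts)
unique_ids = torch.argsort(random_matrix, dim=1)[:, :top_k]

assert all(row.unique().numel() == top_k for row in unique_ids)
```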
36 changes: 22 additions & 14 deletions vllm_ascend/ops/fused_moe/token_dispatcher.py
@@ -25,6 +25,7 @@

import torch
import torch_npu
from vllm.config import get_current_vllm_config
from vllm.distributed.parallel_state import get_ep_group

from vllm_ascend.distributed.parallel_state import get_mc2_group
@@ -100,15 +101,31 @@ def __init__(self, **kwargs):
self.need_extra_args = (
get_ascend_device_type() == AscendDeviceType._910_93)

# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
self.a3_need_extra_args = \
get_ascend_device_type() == AscendDeviceType._910_93
# NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and
# HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
# improve communication performance.
self.need_expert_scale = is_hierarchical_communication_enabled()
self.with_quant = False

# Here we need to calculate global_bs = max_bs_per_rank * ep_world_size so that the
# dispatch & combine operators can run with a different input num_tokens on each rank.
vllm_config = get_current_vllm_config()
scheduler_config = vllm_config.scheduler_config
compilation_config = vllm_config.compilation_config
speculative_config = vllm_config.speculative_config
tp_size = vllm_config.parallel_config.tensor_parallel_size
uniform_decode_query_len = 1 if not speculative_config else \
1 + speculative_config.num_speculative_tokens
decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs',
0)
max_num_reqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs)
if compilation_config.cudagraph_capture_sizes:
max_num_tokens = compilation_config.max_cudagraph_capture_size
else:
max_num_tokens = min(max_num_reqs * uniform_decode_query_len, 512)
num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
self.global_bs = num_tokens_per_tp_rank * self.ep_world_size

def get_dispatch_mc2_kwargs(
self,
hidden_states: torch.Tensor,
@@ -130,7 +147,7 @@ def get_dispatch_mc2_kwargs(
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
"moe_expert_num": moe_expert_num,
"global_bs": 0,
"global_bs": self.global_bs,
"expert_token_nums_type": 0,
}

@@ -147,10 +164,6 @@ def get_dispatch_mc2_kwargs(
"tp_world_size": 1,
"tp_rank_id": 0,
})
if self.a3_need_extra_args and self.enable_dispatch_v2:
stage1_kwargs.update({
"x_active_mask": mc2_mask,
})
if self.need_expert_scale:
stage1_kwargs.update({
"expert_scales":
@@ -214,7 +227,6 @@ def token_dispatch(self,
context_metadata = {
"topk_ids": topk_ids,
"topk_weights": topk_weights,
"mc2_mask": mc2_mask,
"expert_map": expert_map,
"ep_recv_counts": ep_recv_counts,
"tp_recv_counts": tp_recv_counts,
@@ -243,7 +255,6 @@ def get_combine_mc_kwargs(self, hidden_states: torch.Tensor,
ep_recv_counts = context_metadata["ep_recv_counts"]
tp_recv_counts = context_metadata["tp_recv_counts"]
assist_info_for_combine = context_metadata["assist_info_for_combine"]
mc2_mask = context_metadata["mc2_mask"]
expand_scales = context_metadata["expand_scales"]

assert expert_map is not None
@@ -256,7 +267,7 @@ def get_combine_mc_kwargs(self, hidden_states: torch.Tensor,
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
"moe_expert_num": moe_expert_num,
"global_bs": 0,
"global_bs": self.global_bs,
}

if self.with_quant:
@@ -285,9 +296,6 @@ def get_combine_mc_kwargs(self, hidden_states: torch.Tensor,
"tp_rank_id": 0,
})

if self.a3_need_extra_args and self.enable_dispatch_v2:
stage3_kwargs["x_active_mask"] = mc2_mask

kwargs_mc2.update(stage3_kwargs)
return kwargs_mc2

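To make the new global_bs computation concrete, here is a worked example with hypothetical configuration values (no CUDA graph capture sizes and no speculative decoding assumed). The resulting value replaces the hard-coded `"global_bs": 0` in both the dispatch and combine kwargs above:

```python
# Hypothetical values, for illustration only.
max_num_seqs = 256            # scheduler_config.max_num_seqs (decode_max_num_seqs assumed 0)
uniform_decode_query_len = 1  # no speculative_config -> one query token per request
tp_size = 4                   # parallel_config.tensor_parallel_size
ep_world_size = 16

max_num_tokens = min(max_num_seqs * uniform_decode_query_len, 512)  # min(256, 512) = 256
num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size  # ceil(256 / 4)  = 64
global_bs = num_tokens_per_tp_rank * ep_world_size                  # 64 * 16        = 1024
```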
8 changes: 6 additions & 2 deletions vllm_ascend/quantization/w4a8_dynamic.py
@@ -371,8 +371,12 @@ def apply(
# to avoid accumulating too many tokens on a single rank.
# currently it is only activated when doing profile runs.
if enable_force_load_balance:
topk_ids = torch.randint_like(
topk_ids, 0, global_num_experts - global_redundant_expert_num)
random_matrix = torch.rand(topk_ids.size(0),
global_num_experts -
global_redundant_expert_num,
device=topk_ids.device)
topk_ids = torch.argsort(
random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)

topk_weights = topk_weights.to(x.dtype)

8 changes: 6 additions & 2 deletions vllm_ascend/quantization/w8a8_dynamic.py
@@ -216,8 +216,12 @@ def apply(
# to avoid accumulating too many tokens on a single rank.
# currently it is only activated when doing profile runs.
if enable_force_load_balance:
topk_ids = torch.randint_like(
topk_ids, 0, global_num_experts - global_redundant_expert_num)
random_matrix = torch.rand(topk_ids.size(0),
global_num_experts -
global_redundant_expert_num,
device=topk_ids.device)
topk_ids = torch.argsort(
random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)

topk_weights = topk_weights.to(self.in_dtype)

11 changes: 11 additions & 0 deletions vllm_ascend/worker/model_runner_v1.py
@@ -965,6 +965,17 @@ def _sync_metadata_across_dp(
# immediately once the other two flags are no longer needed.
if self.dp_size == 1:
return num_tokens, None, with_prefill
# NOTE: Here we can skip the all_reduce operation and avoid padding tokens
# to max_tokens_across_dp on D nodes. In MoE models, we must ensure that
# num_tokens DOES NOT exceed mc2_tokens_capacity, which means that the
# moe_comm_method of each rank is MC2. It is recommended to enable the
# recompute scheduler for D nodes.
if self.is_kv_consumer and not self.in_profile_run:
num_tokens_after_padding = torch.tensor([num_tokens] *
self.dp_size,
device="cpu",
dtype=torch.int32)
return num_tokens, num_tokens_after_padding, with_prefill

# Sync num_tokens, with_prefill across dp ranks
num_tokens_tensor = torch.tensor([
num_tokens if i == self.dp_rank else 0 for i in range(self.dp_size)
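As a small illustration of the fast path added in `_sync_metadata_across_dp`: on KV-consumer (D) nodes outside profile runs, each rank replicates its own token count instead of syncing and padding to the DP-wide maximum. The sketch below uses hypothetical counts; the "regular path" line paraphrases the behaviour described in the NOTE above rather than the actual implementation:

```python
import torch

dp_size = 4
local_num_tokens = 96                  # this rank's token count on a decode (D) node
tokens_across_dp = [96, 128, 80, 112]  # hypothetical counts on all DP ranks

# Regular path (conceptually): counts are synced and every rank pads up to the max.
padded = torch.tensor([max(tokens_across_dp)] * dp_size, dtype=torch.int32)  # [128, 128, 128, 128]

# KV-consumer fast path: skip the all_reduce and replicate the local count, so
# no tokens are padded up to max_tokens_across_dp on D nodes.
num_tokens_after_padding = torch.tensor([local_num_tokens] * dp_size,
                                        device="cpu", dtype=torch.int32)     # [96, 96, 96, 96]
```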