From d78b60c6d8e5f7ac2c43e25098c6544e6d231cc3 Mon Sep 17 00:00:00 2001 From: xiaoxiaohehe001 Date: Thu, 20 Nov 2025 15:09:59 +0800 Subject: [PATCH] support_moe_quant --- fastdeploy/config.py | 1 + fastdeploy/engine/args_utils.py | 1 + fastdeploy/entrypoints/engine_client.py | 50 +++---- fastdeploy/eplb/experts_manager.py | 65 ++++---- fastdeploy/eplb/utils.py | 140 +++++++++--------- fastdeploy/model_executor/layers/moe/moe.py | 2 +- .../model_executor/models/ernie4_5_moe.py | 3 +- .../models/ernie4_5_vl/ernie4_5_vl_moe.py | 3 +- fastdeploy/worker/worker_process.py | 59 ++++---- 9 files changed, 172 insertions(+), 152 deletions(-) diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 4e549df148f..b6d8e73e263 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -121,6 +121,7 @@ def __init__( ): self.model = "" self.is_quantized = False + self.is_moe_quantized = False self.max_model_len = 0 self.dtype = "" self.enable_logprob = False diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index c0b4e3a01a2..fe4a4ce7614 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -1076,6 +1076,7 @@ def create_eplb_config(self) -> EPLBConfig: if self.eplb_config is not None: for k, v in self.eplb_config.items(): eplb_args[k] = v + eplb_args["enable_eplb"] = self.enable_eplb return EPLBConfig(eplb_args) def create_engine_config(self, port_availability_check: bool = True) -> FDConfig: diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py index e1a359624cd..ad28318f3da 100644 --- a/fastdeploy/entrypoints/engine_client.py +++ b/fastdeploy/entrypoints/engine_client.py @@ -96,7 +96,7 @@ def __init__( else: self.is_master = False - if self.config.eplb_config.enable_eplb and self.config.parallel_config.expert_parallel_rank == 0: + if self.config.eplb_config.enable_eplb: self.init_eplb_signals(ipc_signal_suffix=port) array_size = min(max_chips_per_node, tensor_parallel_size) @@ -126,17 +126,21 @@ def init_eplb_signals(self, ipc_signal_suffix): """ Initialize eplb signals. 
""" + if self.config.parallel_config.tensor_parallel_rank != 0: + # only TP rank 0 need to init eplb signals, rank 0 manage all EPLB signals for all TP ranks + return self.signal_clear_experts_token_stats_list = [] self.local_experts_token_stats_array_list = [] self.expert_tokens_stats_array_list = [] self.signal_update_weight_from_disk_array_list = [] self.update_weight_from_disk_result_list = [] + dp_ipc_signal_suffix = f"{ipc_signal_suffix}_dp{self.config.parallel_config.local_data_parallel_id}" rearrange_experts_status = np.zeros([1], dtype=np.int32) self.rearrange_experts_signal = IPCSignal( name="rearrange_experts_status", array=rearrange_experts_status, dtype=np.int32, - suffix=ipc_signal_suffix, + suffix=dp_ipc_signal_suffix, create=False, ) @@ -145,14 +149,14 @@ def init_eplb_signals(self, ipc_signal_suffix): name="rearrange_experts_ips_size", array=rearrange_experts_ips_size_array, dtype=np.int32, - suffix=ipc_signal_suffix, + suffix=dp_ipc_signal_suffix, create=False, ) self.shm_rearrange_experts_ips_list = IPCSignal( name="rearrange_experts_ips_list", shm_size=self.config.eplb_config.redundant_expert_ip_shm_size, - suffix=ipc_signal_suffix, + suffix=dp_ipc_signal_suffix, create=False, ) @@ -161,27 +165,19 @@ def init_eplb_signals(self, ipc_signal_suffix): name="signal_update_weight_from_tensor", array=signal_update_weight_from_tensor, dtype=np.int32, - suffix=ipc_signal_suffix, + suffix=dp_ipc_signal_suffix, create=False, ) - if envs.FD_ENABLE_MULTI_API_SERVER: - engine_worker_suffix = [ - self.config.parallel_config.engine_worker_queue_port[ - self.config.parallel_config.local_data_parallel_id - ] - ] - else: - engine_worker_suffix = self.config.parallel_config.engine_worker_queue_port - - for suffix_port in engine_worker_suffix: + for tp_rank_id in range(self.config.parallel_config.tensor_parallel_size): + tp_ipc_signal_suffix = f"{dp_ipc_signal_suffix}_tp{tp_rank_id}" signal_clear_experts_token_stats = np.zeros([1], dtype=np.int32) self.signal_clear_experts_token_stats_list.append( IPCSignal( name="signal_clear_experts_token_stats", array=signal_clear_experts_token_stats, dtype=np.int32, - suffix=suffix_port, + suffix=tp_ipc_signal_suffix, create=False, ) ) @@ -192,7 +188,7 @@ def init_eplb_signals(self, ipc_signal_suffix): name="signal_update_weight_from_disk", array=signal_update_weight_from_disk, dtype=np.int32, - suffix=suffix_port, + suffix=tp_ipc_signal_suffix, create=False, ) ) @@ -203,7 +199,7 @@ def init_eplb_signals(self, ipc_signal_suffix): name="result_update_weight_from_disk", array=result_update_weight_from_disk, dtype=np.int32, - suffix=suffix_port, + suffix=tp_ipc_signal_suffix, create=False, ) ) @@ -217,7 +213,7 @@ def init_eplb_signals(self, ipc_signal_suffix): name="all_experts_token_stats", array=experts_token_stats, dtype=np.int32, - suffix=suffix_port, + suffix=tp_ipc_signal_suffix, create=False, ) ) @@ -226,7 +222,7 @@ def init_eplb_signals(self, ipc_signal_suffix): name="local_experts_token_stats", array=experts_token_stats, dtype=np.int32, - suffix=suffix_port, + suffix=tp_ipc_signal_suffix, create=False, ) ) @@ -541,10 +537,10 @@ async def rearrange_experts(self, request_dict: dict): status_code = HTTPStatus.UNAUTHORIZED return content, status_code - if self.config.parallel_config.expert_parallel_rank != 0: + if self.config.parallel_config.tensor_parallel_rank != 0: content = { "code": 1, - "msg": f"actual rank {self.config.parallel_config.expert_parallel_rank}, expect rank 0", + "msg": f"actual rank 
{self.config.parallel_config.tensor_parallel_rank}, expect rank 0", } status_code = HTTPStatus.BAD_REQUEST return content, status_code @@ -589,6 +585,8 @@ async def rearrange_experts(self, request_dict: dict): status_code = HTTPStatus.BAD_REQUEST else: weight = np.array(request_dict["data"], dtype=np.int32) + api_server_logger.info(f"expert_tokens_stats_array_list: {weight}") + for idx in range(len(self.expert_tokens_stats_array_list)): self.expert_tokens_stats_array_list[idx].value[:] = weight[:] self.signal_update_weight_from_disk_array_list[idx].value[0] = 1 @@ -645,10 +643,10 @@ async def get_per_expert_tokens_stats(self, request_dict: dict): status_code = HTTPStatus.UNAUTHORIZED return content, status_code - if self.config.parallel_config.expert_parallel_rank != 0: + if self.config.parallel_config.tensor_parallel_rank != 0: content = { "code": 1, - "msg": f"actual rank {self.config.parallel_config.expert_parallel_rank}, expect rank 0", + "msg": f"actual rank {self.config.parallel_config.tensor_parallel_rank}, expect rank 0", } status_code = HTTPStatus.BAD_REQUEST return content, status_code @@ -688,10 +686,10 @@ async def check_redundant(self, request_dict: dict): status_code = HTTPStatus.UNAUTHORIZED return content, status_code - if self.config.parallel_config.expert_parallel_rank != 0: + if self.config.parallel_config.tensor_parallel_rank != 0: content = { "code": 1, - "msg": f"actual rank {self.config.parallel_config.expert_parallel_rank}, expect rank 0", + "msg": f"actual rank {self.config.parallel_config.tensor_parallel_rank}, expect rank 0", } status_code = HTTPStatus.BAD_REQUEST return content, status_code diff --git a/fastdeploy/eplb/experts_manager.py b/fastdeploy/eplb/experts_manager.py index d2af1a1b168..6b79c9a776e 100644 --- a/fastdeploy/eplb/experts_manager.py +++ b/fastdeploy/eplb/experts_manager.py @@ -38,7 +38,7 @@ class RedundantExpertManager: def __init__( self, rank: int = 0, - ep_size: int = 32, + ep_size: int = 64, fd_config: FDConfig = None, ipc_signal_suffix: int = 0, ): @@ -54,6 +54,7 @@ def __init__( self.num_hidden_layers = self.fd_config.model_config.num_hidden_layers self.num_logical_experts = self.fd_config.model_config.moe_num_experts self.ipc_signal_suffix = ipc_signal_suffix + self.local_rank = self.rank % self.fd_config.parallel_config.tensor_parallel_size self.num_replicas = self.num_logical_experts + self.num_redundant_experts self.num_groups = self.num_logical_experts @@ -171,20 +172,21 @@ def listen_rearrange_expert_signal(self): """ listen_rearrange_expert_signal """ - if self.rank == 0: + dp_ipc_signal_suffix = f"{self.ipc_signal_suffix}_dp{self.fd_config.parallel_config.local_data_parallel_id}" + if self.local_rank == 0: rearrange_experts_ips_size_array = np.zeros([1], dtype=np.int32) rearrange_experts_ips_size_signal = IPCSignal( name="rearrange_experts_ips_size", array=rearrange_experts_ips_size_array, dtype=np.int32, - suffix=self.ipc_signal_suffix, + suffix=dp_ipc_signal_suffix, create=False, ) shm_rearrange_experts_ips_list = IPCSignal( name="rearrange_experts_ips_list", shm_size=self.eplb_config.redundant_expert_ip_shm_size, - suffix=self.ipc_signal_suffix, + suffix=dp_ipc_signal_suffix, create=False, ) @@ -193,16 +195,25 @@ def listen_rearrange_expert_signal(self): name="rearrange_experts_status", array=rearrange_experts_status, dtype=np.int32, - suffix=self.ipc_signal_suffix, + suffix=dp_ipc_signal_suffix, + create=False, + ) + signal_update_weight_from_tensor = np.zeros([1], dtype=np.int32) + self.signal_update_weight_from_tensor_array 
= IPCSignal( + name="signal_update_weight_from_tensor", + array=signal_update_weight_from_tensor, + dtype=np.int32, + suffix=dp_ipc_signal_suffix, create=False, ) + tp_ipc_signal_suffix = f"{dp_ipc_signal_suffix}_tp{self.local_rank}" signal_update_weight_from_disk = np.zeros([1], dtype=np.int32) signal_update_weight_from_disk_array = IPCSignal( name="signal_update_weight_from_disk", array=signal_update_weight_from_disk, dtype=np.int32, - suffix=self.ipc_signal_suffix, + suffix=tp_ipc_signal_suffix, create=False, ) @@ -214,12 +225,21 @@ def listen_rearrange_expert_signal(self): name="all_experts_token_stats", array=experts_token_stats, dtype=np.int32, - suffix=self.ipc_signal_suffix, + suffix=tp_ipc_signal_suffix, + create=False, + ) + + result_update_weight_from_disk = np.zeros([1], dtype=np.int32) + self.update_weight_from_disk_result = IPCSignal( + name="result_update_weight_from_disk", + array=result_update_weight_from_disk, + dtype=np.int32, + suffix=tp_ipc_signal_suffix, create=False, ) while True: - if self.rank == 0: + if self.local_rank == 0: now = int(time.time()) if rearrange_experts_ips_size_signal.value[0] > 0: # step 1. all reduce experts token stats @@ -267,8 +287,8 @@ def caculate_expert_rank_table(self, is_init=False): eplb_strategy = self.eplb_config.redundant_expert_eplb_strategy if is_init: num_groups = 1 - num_nodes = 2 - num_gpus = 2 * 8 + num_nodes = 8 + num_gpus = 8 * 8 eplb_strategy = "" # eplb rank_expert_list, logical_to_physical_map, expert_count = rebalance_experts( @@ -291,7 +311,7 @@ def caculate_expert_rank_table(self, is_init=False): self.model_expert_id_to_ep_rank_array[..., : logical_to_physical_map.shape[-1]] = logical_to_physical_map[:] self.model_expert_in_rank_num_list[:] = expert_count[:] - if self.rank == 0: + if self.local_rank == 0: workload = RedundantExpertWorkload() workload.tokens_per_expert_stats_list = self.model_tokens_per_expert_stats_list.tolist() workload.ep_rank_to_expert_id_list = rank_expert_list.tolist() @@ -304,16 +324,7 @@ def update_weight_from_disk(self): update_weight_from_disk """ begin_time = time.time() - result_update_weight_from_disk = np.zeros([1], dtype=np.int32) - update_weight_from_disk_result = IPCSignal( - name="result_update_weight_from_disk", - array=result_update_weight_from_disk, - dtype=np.int32, - suffix=self.ipc_signal_suffix, - create=False, - ) - update_weight_from_disk_result.value[0] = 0 - + self.update_weight_from_disk_result.value[0] = 0 self.logger.info(f"redundant_expert: update_weight_from_disk send to async process, rank {self.rank}") self.parent_mg_conn.send( { @@ -326,7 +337,7 @@ def update_weight_from_disk(self): self.tensor_infos = response["weights"] # 更新权重加载结果 - update_weight_from_disk_result.value[0] = 1 if response["result"] else -1 + self.update_weight_from_disk_result.value[0] = 1 if response["result"] else -1 self.logger.info( "redundant_expert: update_weight_from_disk end, rank" + f" {self.rank} {response['result']}, cost {int(time.time() - begin_time)}s" @@ -441,15 +452,7 @@ def allreduce_load_weight_result(self): or not self.eplb_config.redundant_expert_enable_schedule_cordon ): self.logger.info("redundant_expert: allreduce_load_weight_result success, notify infer.py") - signal_update_weight_from_tensor = np.zeros([1], dtype=np.int32) - signal_update_weight_from_tensor_array = IPCSignal( - name="signal_update_weight_from_tensor", - array=signal_update_weight_from_tensor, - dtype=np.int32, - suffix=self.ipc_signal_suffix, - create=False, - ) - 
signal_update_weight_from_tensor_array.value[0] = 1 + self.signal_update_weight_from_tensor_array.value[0] = 1 return True def allgather_load_weight_result(self): diff --git a/fastdeploy/eplb/utils.py b/fastdeploy/eplb/utils.py index a4691b6fd3e..22bd68193ba 100644 --- a/fastdeploy/eplb/utils.py +++ b/fastdeploy/eplb/utils.py @@ -69,92 +69,98 @@ def init_eplb_signals(config: FDConfig, ipc_signal_suffix): """ Initialize shared memory to indicate eplb status """ - if config.parallel_config.local_data_parallel_id == 0: - # rearrange_experts_status Record the expert's rearrangement status - rearrange_experts_array = np.zeros([1], dtype=np.int32) - _ = IPCSignal( - name="rearrange_experts_status", - array=rearrange_experts_array, - dtype=np.int32, - suffix=ipc_signal_suffix, - create=True, - ) - - # Record all DP rank IPs when receiving expert rearrangement requests - rearrange_experts_ips_size_array = np.zeros([1], dtype=np.int32) - _ = IPCSignal( - name="rearrange_experts_ips_size", - array=rearrange_experts_ips_size_array, - dtype=np.int32, - suffix=ipc_signal_suffix, - create=True, - ) - _ = IPCSignal( - name="rearrange_experts_ips_list", - shm_size=config.eplb_config.redundant_expert_ip_shm_size, - suffix=ipc_signal_suffix, - create=True, - ) - - # Receive signals for updating weights - signal_update_weight_from_tensor = np.zeros([1], dtype=np.int32) - _ = IPCSignal( - name="signal_update_weight_from_tensor", - array=signal_update_weight_from_tensor, - dtype=np.int32, - suffix=ipc_signal_suffix, - create=True, - ) + if config.parallel_config.tensor_parallel_rank != 0: + # only TP rank 0 need to init eplb signals, rank 0 manage all EPLB signals for all TP ranks + return - # Record expert workload - experts_token_stats = np.zeros( - (config.model_config.num_hidden_layers, config.model_config.moe_num_experts), - dtype=np.int32, - ) + dp_ipc_signal_suffix = f"{ipc_signal_suffix}_dp{config.parallel_config.local_data_parallel_id}" + # rearrange_experts_status Record the expert's rearrangement status + rearrange_experts_array = np.zeros([1], dtype=np.int32) _ = IPCSignal( - name="all_experts_token_stats", - array=experts_token_stats, + name="rearrange_experts_status", + array=rearrange_experts_array, dtype=np.int32, - suffix=ipc_signal_suffix, - create=True, - ) - _ = IPCSignal( - name="local_experts_token_stats", - array=experts_token_stats, - dtype=np.int32, - suffix=ipc_signal_suffix, + suffix=dp_ipc_signal_suffix, create=True, ) - # Receive signals for loading weights - signal_update_weight_from_disk = np.zeros([1], dtype=np.int32) + # Record all DP rank IPs when receiving expert rearrangement requests + rearrange_experts_ips_size_array = np.zeros([1], dtype=np.int32) _ = IPCSignal( - name="signal_update_weight_from_disk", - array=signal_update_weight_from_disk, + name="rearrange_experts_ips_size", + array=rearrange_experts_ips_size_array, dtype=np.int32, - suffix=ipc_signal_suffix, + suffix=dp_ipc_signal_suffix, create=True, ) - - # Receive signals for clearing expert loads - clear_experts_token_stats = np.zeros([1], dtype=np.int32) _ = IPCSignal( - name="signal_clear_experts_token_stats", - array=clear_experts_token_stats, - dtype=np.int32, - suffix=ipc_signal_suffix, + name="rearrange_experts_ips_list", + shm_size=config.eplb_config.redundant_expert_ip_shm_size, + suffix=dp_ipc_signal_suffix, create=True, ) - result_update_weight_from_disk = np.zeros([1], dtype=np.int32) + # Receive signals for updating weights + signal_update_weight_from_tensor = np.zeros([1], dtype=np.int32) _ = 
IPCSignal( - name="result_update_weight_from_disk", - array=result_update_weight_from_disk, + name="signal_update_weight_from_tensor", + array=signal_update_weight_from_tensor, dtype=np.int32, - suffix=ipc_signal_suffix, + suffix=dp_ipc_signal_suffix, create=True, ) + for rank_id in range(config.parallel_config.tensor_parallel_size): + tp_ipc_signal_suffix = f"{dp_ipc_signal_suffix}_tp{rank_id}" + # Record expert workload + experts_token_stats = np.zeros( + (config.model_config.num_hidden_layers, config.model_config.moe_num_experts), + dtype=np.int32, + ) + _ = IPCSignal( + name="all_experts_token_stats", + array=experts_token_stats, + dtype=np.int32, + suffix=tp_ipc_signal_suffix, + create=True, + ) + _ = IPCSignal( + name="local_experts_token_stats", + array=experts_token_stats, + dtype=np.int32, + suffix=tp_ipc_signal_suffix, + create=True, + ) + + # Receive signals for loading weights + signal_update_weight_from_disk = np.zeros([1], dtype=np.int32) + _ = IPCSignal( + name="signal_update_weight_from_disk", + array=signal_update_weight_from_disk, + dtype=np.int32, + suffix=tp_ipc_signal_suffix, + create=True, + ) + + # Receive signals for clearing expert loads + clear_experts_token_stats = np.zeros([1], dtype=np.int32) + _ = IPCSignal( + name="signal_clear_experts_token_stats", + array=clear_experts_token_stats, + dtype=np.int32, + suffix=tp_ipc_signal_suffix, + create=True, + ) + + result_update_weight_from_disk = np.zeros([1], dtype=np.int32) + _ = IPCSignal( + name="result_update_weight_from_disk", + array=result_update_weight_from_disk, + dtype=np.int32, + suffix=tp_ipc_signal_suffix, + create=True, + ) + if __name__ == "__main__": print(RedundantExpertWorkload("/tmp").load()) diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index 8b83aeccabe..35078def45b 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -524,7 +524,7 @@ def load_state_dict(self, state_dict, is_rearrange: bool = False): """ load_state_dict function. 
""" - if self.fd_config.model_config.is_quantized: + if self.fd_config.model_config.is_quantized or self.fd_config.model_config.is_moe_quantized: if getattr(self.fd_config.quant_config, "is_permuted", True): self.quant_method.process_prequanted_weights(self, state_dict, is_rearrange) else: diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py index 9e02a44ef2c..6db00304632 100644 --- a/fastdeploy/model_executor/models/ernie4_5_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_moe.py @@ -141,7 +141,8 @@ def __init__( "down_proj_expert_code_zp_key": f"{prefix}.experts.{{}}.down_proj.code_zp", } elif moe_quant_type == "tensor_wise_fp8" or ( - moe_quant_type == "block_wise_fp8" and fd_config.model_config.is_quantized + moe_quant_type == "block_wise_fp8" + and (fd_config.model_config.is_quantized or fd_config.model_config.is_moe_quantized) ): weight_key_map = { "gate_weight_key": f"{prefix}.gate.weight", diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py index fff67033414..e7b3dea89eb 100644 --- a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py @@ -93,7 +93,8 @@ def __init__( moe_quant_type = fd_config.quant_config.moe_quant_type if moe_quant_type == "tensor_wise_fp8" or ( - moe_quant_type == "block_wise_fp8" and fd_config.model_config.is_quantized + moe_quant_type == "block_wise_fp8" + and (fd_config.model_config.is_quantized or fd_config.model_config.is_moe_quantized) ): weight_key_map = { "gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias", diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 84ed49948e2..ec54e6bbb48 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -287,6 +287,7 @@ def event_loop_normal(self) -> None: """Main event loop for Paddle Distrubuted Workers. TODO(gongshaotian): support remote calling of functions that control worker. 
""" + local_rank = self.local_rank % self.parallel_config.tensor_parallel_size if self.eplb_config.enable_eplb: self.last_dump_expert_workload_ts = 0 self.experts_manager = RedundantExpertManager( @@ -295,6 +296,29 @@ def event_loop_normal(self) -> None: fd_config=self.fd_config, ipc_signal_suffix=self.parallel_config.engine_worker_queue_port, ) + dp_ipc_signal_suffix = ( + f"{self.parallel_config.engine_worker_queue_port}_dp{self.parallel_config.local_data_parallel_id}" + ) + if local_rank == 0: + signal_update_weight_from_tensor = np.zeros([1], dtype=np.int32) + signal_update_weight_from_tensor_array = IPCSignal( + name="signal_update_weight_from_tensor", + array=signal_update_weight_from_tensor, + dtype=np.int32, + suffix=dp_ipc_signal_suffix, + create=False, + ) + + rearrange_experts_status = np.zeros([1], dtype=np.int32) + rearrange_experts_signal = IPCSignal( + name="rearrange_experts_status", + array=rearrange_experts_status, + dtype=np.int32, + suffix=dp_ipc_signal_suffix, + create=False, + ) + + tp_ipc_signal_suffix = f"{dp_ipc_signal_suffix}_tp{local_rank}" experts_token_stats = np.zeros( (self.fd_config.model_config.num_hidden_layers, self.fd_config.model_config.moe_num_experts), dtype=np.int32, @@ -303,7 +327,7 @@ def event_loop_normal(self) -> None: name="local_experts_token_stats", array=experts_token_stats, dtype=np.int32, - suffix=self.parallel_config.engine_worker_queue_port, + suffix=tp_ipc_signal_suffix, create=False, ) @@ -312,29 +336,10 @@ def event_loop_normal(self) -> None: name="signal_clear_experts_token_stats", array=clear_experts_token_stats, dtype=np.int32, - suffix=self.parallel_config.engine_worker_queue_port, + suffix=tp_ipc_signal_suffix, create=False, ) - if self.local_rank == 0: - signal_update_weight_from_tensor = np.zeros([1], dtype=np.int32) - signal_update_weight_from_tensor_array = IPCSignal( - name="signal_update_weight_from_tensor", - array=signal_update_weight_from_tensor, - dtype=np.int32, - suffix=self.parallel_config.engine_worker_queue_port, - create=False, - ) - - rearrange_experts_status = np.zeros([1], dtype=np.int32) - rearrange_experts_signal = IPCSignal( - name="rearrange_experts_status", - array=rearrange_experts_status, - dtype=np.int32, - suffix=self.parallel_config.engine_worker_queue_port, - create=False, - ) - mmap_infos = create_mmap( [MODEL_MAIN_NAME], self.local_rank, @@ -348,7 +353,6 @@ def event_loop_normal(self) -> None: self.nnode = int((self.parallel_config.tensor_parallel_size + 7) // 8) req_ids = [] num_running_requests = 0 - local_rank = self.local_rank % self.parallel_config.tensor_parallel_size self.model_weights_signal = np.zeros([1], dtype=np.int32) attention_dp_cached_prefill_tasks = [] attention_dp_wait_prefill_iters = 0 @@ -379,7 +383,7 @@ def event_loop_normal(self) -> None: # 所有DP同步更新权重 broadcast_value = 0 - if self.local_rank == 0 and signal_update_weight_from_tensor_array.value[0] == 1: + if local_rank == 0 and signal_update_weight_from_tensor_array.value[0] == 1: logger.info("redundant_expert: update_weight_from_tensor broadcast signal") signal_update_weight_from_tensor_array.value[0] = 0 broadcast_value = REARRANGE_EXPERT_MAGIC_NUM @@ -391,7 +395,7 @@ def event_loop_normal(self) -> None: f"redundant_expert: update_weight_from_tensor success, cost {(time.time() - rearrange_time)*1000}ms" ) paddle.distributed.barrier() - if self.local_rank == 0: + if local_rank == 0: rearrange_experts_signal.value[0] = RearrangeExpertStatus.DONE.value logger.info("redundant_expert: done") if local_rank == 0: @@ -919,8 +923,13 
@@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig: if quantization_config is not None: if "is_quantized" in quantization_config: model_config.is_quantized = quantization_config["is_quantized"] + elif "is_moe_quantized" in quantization_config: + model_config.is_moe_quantized = quantization_config["is_moe_quantized"] elif "kv_cache_quant_type" not in quantization_config: - model_config.is_quantized = True + if "is_moe_quantized" not in quantization_config: + model_config.is_quantized = True + else: + model_config.is_moe_quantized = True quant_config_name = None if quantization_config is not None and quantization_config.get("quantization", None) is None:
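A minimal standalone restatement of the quantization-flag resolution added in the last hunk above, assuming a plain dict for quantization_config and a small stand-in for ModelConfig (QuantFlags below is hypothetical and exists only for illustration; it is not part of this patch):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class QuantFlags:
        # Stand-in for the two ModelConfig fields touched by this patch.
        is_quantized: bool = False
        is_moe_quantized: bool = False

    def resolve_quant_flags(quantization_config: Optional[dict]) -> QuantFlags:
        flags = QuantFlags()
        if quantization_config is not None:
            if "is_quantized" in quantization_config:
                flags.is_quantized = quantization_config["is_quantized"]
            elif "is_moe_quantized" in quantization_config:
                flags.is_moe_quantized = quantization_config["is_moe_quantized"]
            elif "kv_cache_quant_type" not in quantization_config:
                if "is_moe_quantized" not in quantization_config:
                    # Neither flag present and not a pure KV-cache quant config:
                    # treat the checkpoint as fully pre-quantized.
                    flags.is_quantized = True
                else:
                    # Unreachable as ordered: the earlier elif already handles
                    # configs that contain "is_moe_quantized".
                    flags.is_moe_quantized = True
        return flags

    if __name__ == "__main__":
        assert resolve_quant_flags({"is_quantized": True}).is_quantized
        assert resolve_quant_flags({"is_moe_quantized": True}).is_moe_quantized
        assert resolve_quant_flags({}).is_quantized
        assert not resolve_quant_flags({"kv_cache_quant_type": "int8"}).is_quantized

Downstream, the MoE layer's load_state_dict (moe.py) takes the process_prequanted_weights path when either flag is set, and the ernie4_5 MoE layers select the block_wise_fp8 pre-quantized weight key map under the same combined condition.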
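For reference, the IPC suffix convention this patch switches the EPLB signals to: DP-scoped signals (rearrange_experts_status, rearrange_experts_ips_size, rearrange_experts_ips_list, signal_update_weight_from_tensor) use "<engine_worker_queue_port>_dp<local_data_parallel_id>", while per-rank signals (all_experts_token_stats, local_experts_token_stats, signal_update_weight_from_disk, signal_clear_experts_token_stats, result_update_weight_from_disk) additionally append "_tp<rank>". The patch inlines the f-strings at each call site; the helpers below are hypothetical and shown only to make the composition explicit:

    def dp_suffix(engine_worker_queue_port: str, local_data_parallel_id: int) -> str:
        # DP-scoped EPLB signals, created and polled by TP rank 0 of each DP group.
        return f"{engine_worker_queue_port}_dp{local_data_parallel_id}"

    def tp_suffix(engine_worker_queue_port: str, local_data_parallel_id: int, tp_rank: int) -> str:
        # Per-TP-rank signals (expert token stats, disk weight reload, clear-stats).
        return f"{dp_suffix(engine_worker_queue_port, local_data_parallel_id)}_tp{tp_rank}"

    if __name__ == "__main__":
        # Illustrative values: port "8021", DP group 0, 4-way TP.
        assert dp_suffix("8021", 0) == "8021_dp0"
        assert [tp_suffix("8021", 0, r) for r in range(4)] == [
            "8021_dp0_tp0", "8021_dp0_tp1", "8021_dp0_tp2", "8021_dp0_tp3",
        ]

Because these segments are created once in eplb/utils.init_eplb_signals (TP rank 0 only, create=True) and opened everywhere else with create=False, producers and consumers must derive byte-identical suffix strings, which is why the composition stays this mechanical.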