Skip to content

Commit 5b52a51

Browse files
author
ddchenhao66
committed
[XPU] XPU supports PD disaggregation in v1 scheduler
1 parent 77d25ba commit 5b52a51

File tree

8 files changed

+49
-19
lines changed

8 files changed

+49
-19
lines changed

custom_ops/xpu_ops/src/ops/remote_cache_kv_ipc.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,14 @@ struct RemoteCacheKvIpc {
7272
}
7373

7474
void send_signal() {
75-
msg_sed.mtext[1] = layer_id_;
76-
if ((msgsnd(msgid, &msg_sed, (MAX_BSZ * 3 + 2) * 4, 0)) == -1) {
77-
printf("kv signal full msg buffer\n");
75+
if (inited) {
76+
msg_sed.mtext[1] = layer_id_;
77+
if ((msgsnd(msgid, &msg_sed, (MAX_BSZ * 3 + 2) * 4, 0)) == -1) {
78+
printf("kv signal full msg buffer\n");
79+
}
80+
layer_id_ = (layer_id_ + 1);
81+
assert(layer_id_ <= num_layers_);
7882
}
79-
layer_id_ = (layer_id_ + 1);
80-
assert(layer_id_ <= num_layers_);
8183
}
8284
};
8385

fastdeploy/cache_manager/cache_messager.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,6 @@ def __init__(
161161
for layer_idx in range(self.num_layers):
162162
key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
163163
val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
164-
logger.info(
165-
f"[key_cache: {hex(key_cache.data_ptr())}],[key_cache_mem: {hex(get_peer_mem_addr(key_cache.data_ptr()))}]"
166-
)
167164
cache_k.append(key_cache)
168165
cache_v.append(val_cache)
169166
if paddle.is_compiled_with_xpu():
@@ -465,8 +462,12 @@ def __init__(
465462
val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
466463
cache_k.append(key_cache)
467464
cache_v.append(val_cache)
468-
cache_k_ptr_list.append(key_cache.data_ptr())
469-
cache_v_ptr_list.append(val_cache.data_ptr())
465+
if paddle.is_compiled_with_xpu():
466+
cache_k_ptr_list.append(get_peer_mem_addr(key_cache.data_ptr()))
467+
cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr()))
468+
else:
469+
cache_k_ptr_list.append(key_cache.data_ptr())
470+
cache_v_ptr_list.append(val_cache.data_ptr())
470471
cache_k_ptr_list = np.array(cache_k_ptr_list)
471472
cache_v_ptr_list = np.array(cache_v_ptr_list)
472473

@@ -771,7 +772,7 @@ def _handle_connect_task(self):
771772
def main():
772773
device = args.device_id
773774
rank = args.rank
774-
set_device(args.rank)
775+
set_device(device)
775776
cache_type = args.cache_dtype
776777
speculative_config = SpeculativeConfig(args.speculative_config)
777778
num_extra_layers = speculative_config.num_extra_cache_layer
@@ -883,7 +884,6 @@ def main():
883884
args = parse_args()
884885
rank_id = args.rank + args.local_data_parallel_id * args.mp_num
885886
logger = get_logger("cache_messager", f"cache_messager_rank{rank_id}.log")
886-
887887
logger.info("create cache messager...")
888888
logger.info(f"{args}")
889889
main()

fastdeploy/cache_manager/cache_transfer_manager.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def _init_gpu_cache(self, args):
204204
logger.info(f"[rank {self.rank}/{self.n_ranks}] OK! Stop waiting.")
205205

206206
logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing kv cache for all layers.")
207-
set_device(self.rank)
207+
set_device(self.device)
208208
for i in range(args.num_layers + self.num_extra_layers):
209209
num_gpu_blocks = self.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
210210
key_name = f"key_caches_{i}_rank{self.rank}.device{self.device}"
@@ -569,7 +569,7 @@ def clear_or_update_caches(self, args):
569569
time.sleep(0.1)
570570

571571
# clear gpu caches
572-
set_device(self.rank)
572+
set_device(self.device)
573573
for name, tensor in self.gpu_cache_kvs.items():
574574
unset_data_ipc(tensor, name, True, False)
575575
self.gpu_cache_kvs.clear()
@@ -640,5 +640,5 @@ def main():
640640
args = parse_args()
641641
rank_id = args.rank + args.local_data_parallel_id * args.mp_num
642642
logger = get_logger("cache_transfer_manager", f"cache_transfer_manager_rank{rank_id}.log")
643-
set_device(rank_id)
643+
set_device(args.device_id)
644644
main()

fastdeploy/cache_manager/ops.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,13 @@ def share_external_data_(cache, cache_name, cache_shape, use_ipc):
6767
return cache
6868

6969

70+
def get_all_visible_devices():
71+
if current_platform.is_xpu():
72+
return "XPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
73+
else:
74+
return "CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
75+
76+
7077
__all__ = [
7178
"cuda_host_alloc",
7279
"cuda_host_free",
@@ -81,4 +88,5 @@ def share_external_data_(cache, cache_name, cache_shape, use_ipc):
8188
"ipc_sent_key_value_cache_by_remote_ptr",
8289
"ipc_sent_key_value_cache_by_remote_ptr_block_sync",
8390
"get_peer_mem_addr",
91+
"get_all_visible_devices",
8492
]

fastdeploy/cache_manager/prefix_cache_manager.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from fastdeploy import envs
3434
from fastdeploy.cache_manager.cache_data import BlockNode, CacheStatus
3535
from fastdeploy.cache_manager.cache_metrics import CacheMetrics
36+
from fastdeploy.cache_manager.ops import get_all_visible_devices
3637
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, PrefixTreeStatus
3738
from fastdeploy.metrics.metrics import main_process_metrics
3839
from fastdeploy.utils import get_logger
@@ -243,9 +244,11 @@ def launch_cache_manager(
243244
# Run command to launch cache transfer managers
244245
log_dir = envs.FD_LOG_DIR
245246
cache_manager_processes = []
247+
visible_devices = get_all_visible_devices()
246248
for i in range(tensor_parallel_size):
247249
launch_cmd = (
248-
"FLAGS_allocator_strategy=auto_growth CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
250+
"FLAGS_allocator_strategy=auto_growth "
251+
+ visible_devices
249252
+ " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
250253
+ f" FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}"
251254
+ f" {sys.executable} {py_path}"
@@ -328,9 +331,11 @@ def launch_cache_messager(
328331
py_path = os.path.join(current_dir_path, filename)
329332
log_dir = envs.FD_LOG_DIR
330333
cache_messager_processes = []
334+
visible_devices = get_all_visible_devices()
331335
for i in range(tensor_parallel_size):
332336
launch_cmd = (
333-
"FLAGS_allocator_strategy=auto_growth CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
337+
"FLAGS_allocator_strategy=auto_growth "
338+
+ visible_devices
334339
+ " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
335340
+ f" {sys.executable} {py_path}"
336341
+ f" --device_id {int(device_ids[i])}"

fastdeploy/model_executor/layers/attention/ops/init_kv_signal_per_query.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,17 @@ def init_kv_signal_per_query(
3232
if current_platform.is_cuda():
3333
from fastdeploy.model_executor.ops.gpu import init_kv_signal_per_query
3434

35+
out = init_kv_signal_per_query(
36+
seq_lens_encoder,
37+
seq_lens_this_time,
38+
seq_lens_decoder,
39+
rank,
40+
num_layers,
41+
)
42+
return out
43+
elif current_platform.is_xpu():
44+
from fastdeploy.model_executor.ops.xpu import init_kv_signal_per_query
45+
3546
out = init_kv_signal_per_query(
3647
seq_lens_encoder,
3748
seq_lens_this_time,

fastdeploy/model_executor/layers/attention/xpu_attn_backend.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ def __init__(
9191
)
9292
self.causal: bool = getattr(fd_config.model_config, "causal", True)
9393
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
94+
self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"])
9495

9596
self.kv_num_heads: int = kv_num_heads
9697
self.num_heads: int = num_heads
@@ -122,7 +123,7 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
122123

123124
# pd_disaggregation
124125
metadata.kv_signal_data_list = [None] * self.num_layers
125-
if self.pd_disaggregation_mode == "per_chunk":
126+
if self.pd_disaggregation_mode == "per_chunk" and not forward_meta.is_profiling:
126127
if not self.keep_pd_step_flag:
127128
init_kv_signal_per_query(
128129
forward_meta.seq_lens_encoder,

fastdeploy/worker/xpu_model_runner.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def xpu_pre_process(
7171
draft_tokens: Optional[paddle.Tensor] = None,
7272
seq_lens_encoder: Optional[paddle.Tensor] = None,
7373
seq_lens_decoder: Optional[paddle.Tensor] = None,
74+
is_profiling: bool = False,
7475
) -> XPUForwardMeta:
7576
""" """
7677
max_len = input_ids.shape[1]
@@ -155,6 +156,8 @@ def xpu_pre_process(
155156

156157
share_inputs["ids_remove_padding"] = adjusted_input
157158
xpu_forward_meta.ids_remove_padding = adjusted_input
159+
# Set forward_meta.is_profiling to True to skip init_kv_signal_per_query for attention backends
160+
xpu_forward_meta.is_profiling = is_profiling
158161
return xpu_forward_meta
159162

160163

@@ -924,6 +927,7 @@ def _prepare_inputs(self, is_dummy_run=False) -> None:
924927
draft_tokens=None,
925928
seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
926929
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
930+
is_profiling=is_dummy_run,
927931
)
928932
# Update bad tokens len
929933
max_bad_tokens_len = paddle.max(self.share_inputs["bad_tokens_len"])
@@ -1188,7 +1192,6 @@ class at the server level, which is too granular for ModelRunner.
11881192
self.kv_signal_sender = create_kv_signal_sender()
11891193
# 1. Prepare inputs of model and decoder.
11901194
self._prepare_inputs(is_dummy_run=is_dummy_run)
1191-
11921195
# NOTE(wufeisheng): If `not_need_stop`` is False, it means the current worker is in an idle state.
11931196
# This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
11941197
# when there is data on other runner, the current runner is required to execute part of the model.

0 commit comments

Comments
 (0)