Commit e70e227

[PD Disaggregation][XPU] Add XPU support for PD disaggregation (#5113)
* [XPU] XPU support for PD disaggregation
* [XPU] Fix the cache KV transfer process startup failure on non-zero XPU cards
* [XPU] XPU support for PD disaggregation in the v1 scheduler

Co-authored-by: ddchenhao66 <dhaochen@163.com>
1 parent 79f1833 commit e70e227

File tree

16 files changed: +274 -82 lines changed

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ops/remote_cache_kv_ipc.h"
+#include "paddle/extension.h"
+
+#ifndef PD_BUILD_STATIC_OP
+#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
+#endif
+
+using cache_write_complete_signal_type =
+    RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data;
+
+paddle::Tensor InitSignalLayerwiseFunc(const paddle::Tensor& kv_signal_metadata,
+                                       const int layer_id) {
+  auto kv_signal_metadata_out =
+      kv_signal_metadata.copy_to(paddle::CPUPlace(), false);
+  kv_signal_metadata_out.data<int64_t>()[0] = static_cast<int64_t>(layer_id);
+  return kv_signal_metadata_out;
+}
+
+std::vector<paddle::Tensor> InitSignalLayerwise(
+    const paddle::Tensor& kv_signal_metadata, const int layer_id) {
+  return {InitSignalLayerwiseFunc(kv_signal_metadata, layer_id)};
+}
+
+std::vector<std::vector<int64_t>> InitSignalLayerwiseShape(
+    const std::vector<int64_t>& kv_signal_metadata_shape, const int layer_id) {
+  return {kv_signal_metadata_shape};
+}
+
+std::vector<paddle::DataType> InitSignalLayerwiseDtype(
+    const paddle::DataType& kv_signal_metadata_dtype, const int layer_id) {
+  return {paddle::DataType::INT64};
+}
+
+PD_BUILD_STATIC_OP(init_signal_layerwise)
+    .Inputs({"kv_signal_metadata"})
+    .Outputs({"kv_signal_metadata_out"})
+    .Attrs({"layer_id: int"})
+    .SetKernelFn(PD_KERNEL(InitSignalLayerwise))
+    .SetInferShapeFn(PD_INFER_SHAPE(InitSignalLayerwiseShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(InitSignalLayerwiseDtype));
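
For context, this new op copies the kv-signal metadata tensor to the CPU and writes the current layer id into its first element. A minimal sketch of how such a custom op would be driven from Python, assuming the XPU build exposes it as fastdeploy.model_executor.ops.xpu.init_signal_layerwise (the module path, tensor setup, and layer count are illustrative assumptions, not code from this commit):

# Hypothetical driver loop; module path, tensor setup, and layer count are assumptions.
import paddle
from fastdeploy.model_executor.ops.xpu import init_signal_layerwise  # assumed export name

num_layers = 4
kv_signal_metadata = paddle.zeros([3], dtype="int64")  # metadata slot stamped with the layer id
for layer_id in range(num_layers):
    # Returns a CPU copy of the metadata with element 0 set to layer_id.
    kv_signal_metadata = init_signal_layerwise(kv_signal_metadata, layer_id)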

custom_ops/xpu_ops/src/ops/open_shm_and_get_meta_signal.cc

Lines changed: 16 additions & 11 deletions
@@ -17,22 +17,27 @@
 #include "ops/utility/env.h"
 #include "paddle/extension.h"
 
+#ifndef PD_BUILD_STATIC_OP
+#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
+#endif
+
 XPU_DECLARE_BOOL(fmt_write_cache_completed_signal, false);
 
 using cache_write_complete_signal_type =
     RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data;
 
 paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank,
+                                           const int device_id,
                                            const bool keep_pd_step_flag) {
   cache_write_complete_signal_type kv_signal_metadata;
-  const char *fmt_write_cache_completed_signal_str =
+  const char* fmt_write_cache_completed_signal_str =
       std::getenv("FLAGS_fmt_write_cache_completed_signal");
   if (fmt_write_cache_completed_signal_str &&
      (std::strcmp(fmt_write_cache_completed_signal_str, "true") == 0 ||
       std::strcmp(fmt_write_cache_completed_signal_str, "1") == 0)) {
    kv_signal_metadata =
        RemoteCacheKvIpc::open_shm_and_get_complete_signal_meta_data(
-            rank, keep_pd_step_flag);
+            rank, device_id, keep_pd_step_flag);
  }
 
  auto kv_signal_metadata_out =
@@ -46,9 +51,9 @@ paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank,
   return kv_signal_metadata_out;
 }
 
-void InitKVSignalPerQuery(const paddle::Tensor &seq_lens_encoder_tensor,
-                          const paddle::Tensor &seq_lens_this_time_tensor,
-                          const paddle::Tensor &seq_lens_decoder_tensor,
+void InitKVSignalPerQuery(const paddle::Tensor& seq_lens_encoder_tensor,
+                          const paddle::Tensor& seq_lens_this_time_tensor,
+                          const paddle::Tensor& seq_lens_decoder_tensor,
                           const int rank,
                           const int num_layers) {
   if (FLAGS_fmt_write_cache_completed_signal) {
@@ -68,24 +73,24 @@ void InitKVSignalPerQuery(const paddle::Tensor &seq_lens_encoder_tensor,
 }
 
 std::vector<paddle::Tensor> OpenShmAndGetMetaSignal(
-    const int rank, const bool keep_pd_step_flag) {
-  return {OpenShmAndGetMetaSignalFunc(rank, keep_pd_step_flag)};
+    const int rank, const int device_id, const bool keep_pd_step_flag) {
+  return {OpenShmAndGetMetaSignalFunc(rank, device_id, keep_pd_step_flag)};
 }
 
 std::vector<std::vector<int64_t>> OpenShmAndGetMetaSignalShape(
-    const int rank, const bool keep_pd_step_flag) {
+    const int rank, const int device_id, const bool keep_pd_step_flag) {
   return {{3}};
 }
 
 std::vector<paddle::DataType> OpenShmAndGetMetaSignalDtype(
-    const int rank, const bool keep_pd_step_flag) {
+    const int rank, const int device_id, const bool keep_pd_step_flag) {
   return {paddle::DataType::INT64};
 }
 
-PD_BUILD_OP(open_shm_and_get_meta_signal)
+PD_BUILD_STATIC_OP(open_shm_and_get_meta_signal)
     .Inputs({})
     .Outputs({"kv_signal_metadata"})
-    .Attrs({"rank: int", "keep_pd_step_flag: bool"})
+    .Attrs({"rank: int", "device_id: int", "keep_pd_step_flag: bool"})
     .SetKernelFn(PD_KERNEL(OpenShmAndGetMetaSignal))
     .SetInferShapeFn(PD_INFER_SHAPE(OpenShmAndGetMetaSignalShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(OpenShmAndGetMetaSignalDtype));
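
With device_id now a separate attribute, the caller passes the physical XPU card index alongside the tensor-parallel rank. A hedged sketch of the updated call site, assuming the op is exported as fastdeploy.model_executor.ops.xpu.open_shm_and_get_meta_signal (the surrounding variables are illustrative):

# Hypothetical call; rank and device values are made up for illustration.
from fastdeploy.model_executor.ops.xpu import open_shm_and_get_meta_signal  # assumed export name

rank, device_id = 0, 2          # tensor-parallel rank 0 pinned to XPU card 2
keep_pd_step_flag = False
# Returns an int64 tensor of shape [3] describing the layer-wise cache-written signal region.
kv_signal_metadata = open_shm_and_get_meta_signal(rank, device_id, keep_pd_step_flag)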

custom_ops/xpu_ops/src/ops/remote_cache_kv_ipc.cc

Lines changed: 6 additions & 5 deletions
@@ -26,7 +26,7 @@ bool RemoteCacheKvIpc::kv_complete_signal_shmem_opened = false;
 
 RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data
 RemoteCacheKvIpc::open_shm_and_get_complete_signal_meta_data(
-    const int rank_id, const bool keep_pd_step_flag) {
+    const int rank_id, const int device_id, const bool keep_pd_step_flag) {
   if (RemoteCacheKvIpc::kv_complete_signal_shmem_opened) {
     if (keep_pd_step_flag) {
       return RemoteCacheKvIpc::kv_complete_signal_meta_data;
@@ -47,12 +47,13 @@ RemoteCacheKvIpc::open_shm_and_get_complete_signal_meta_data(
     std::string iflags_server_uuid_env_str(iflags_server_uuid_env_p);
     flags_server_uuid = iflags_server_uuid_env_str;
   }
+
   std::string step_shm_name =
-      ("splitwise_complete_prefilled_step_" + std::to_string(rank_id) + "_" +
-       flags_server_uuid);
+      ("splitwise_complete_prefilled_step_" + std::to_string(rank_id) + "." +
+       std::to_string(device_id));
   std::string layer_shm_name =
-      ("splitwise_complete_prefilled_layer_" + std::to_string(rank_id) + "_" +
-       flags_server_uuid);
+      ("splitwise_complete_prefilled_layer_" + std::to_string(rank_id) + "." +
+       std::to_string(device_id));
   if (const char* use_ep = std::getenv("ENABLE_EP_DP")) {
     if (std::strcmp(use_ep, "1") == 0) {
       step_shm_name = "splitwise_complete_prefilled_step_tprank0_dprank" +
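
The shared-memory names now end in rank_id.device_id instead of the server UUID, so a cache messager running on a non-zero XPU card opens a region that cannot collide with card 0. A small illustration of the resulting names (plain string formatting, not repository code):

# Illustration only: mirrors the C++ string concatenation above.
def splitwise_shm_names(rank_id: int, device_id: int):
    step = f"splitwise_complete_prefilled_step_{rank_id}.{device_id}"
    layer = f"splitwise_complete_prefilled_layer_{rank_id}.{device_id}"
    return step, layer

print(splitwise_shm_names(0, 3))
# ('splitwise_complete_prefilled_step_0.3', 'splitwise_complete_prefilled_layer_0.3')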

custom_ops/xpu_ops/src/ops/remote_cache_kv_ipc.h

Lines changed: 1 addition & 0 deletions
@@ -93,6 +93,7 @@ struct RemoteCacheKvIpc {
 
   static RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data
   open_shm_and_get_complete_signal_meta_data(const int rank_id,
+                                             const int device_id,
                                              const bool keep_pd_step_flag);
   static void save_cache_kv_complete_signal_layerwise(void* meta_data);
   static void save_cache_kv_complete_signal_layerwise_per_query(

custom_ops/xpu_ops/src/ops/share_external_data.cc

Lines changed: 7 additions & 7 deletions
@@ -19,26 +19,26 @@
 #include "xpu/plugin.h"
 #include "xpu_multiprocess.h"  // NOLINT(build/include_subdir)
 
-std::vector<paddle::Tensor> ShareExternalData(const paddle::Tensor &input,
+std::vector<paddle::Tensor> ShareExternalData(const paddle::Tensor& input,
                                               const std::string shm_name,
-                                              const std::vector<int> &shape,
+                                              const std::vector<int>& shape,
                                               bool use_ipc) {
   sharedMemoryInfo info;
   int ret = sharedMemoryOpen(shm_name.c_str(), sizeof(shmStruct), &info);
   PD_CHECK(ret == 0, "sharedMemoryOpen failed");
-  volatile shmStruct *shm = static_cast<volatile shmStruct *>(info.addr);
-  void *data_ptr_addr = nullptr;
+  volatile shmStruct* shm = static_cast<volatile shmStruct*>(info.addr);
+  void* data_ptr_addr = nullptr;
   if (use_ipc) {
 #if XPURT_VERSION_MAJOR == 5
     int ret = xpu_ipc_open_memhandle(&data_ptr_addr,
-                                     *(XPUIpcMemHandle *)&shm->memHandle,
+                                     *(XPUIpcMemHandle*)&shm->memHandle,
                                      0x01);  // NOLINT
-    PD_CHECK(ret == XPU_SUCCESS, "xpu_ipc_open_memhandle failed");
+    PD_CHECK(ret == XPU_SUCCESS, shm_name, " xpu_ipc_open_memhandle failed");
 #elif XPURT_VERSION_MAJOR == 4
     PD_THROW("kl2 not support prefix cache");
 #endif
   } else {
-    data_ptr_addr = reinterpret_cast<void *>(shm->data_ptr_addr);
+    data_ptr_addr = reinterpret_cast<void*>(shm->data_ptr_addr);
   }
 
   phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
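
On the Python side this kernel is reached through the share_external_data_ wrapper whose signature appears in the fastdeploy/cache_manager/ops.py hunk below. A hedged usage sketch, with the shared-memory name and cache shape invented for illustration:

# Hypothetical usage; the shared-memory name and cache shape are made up.
import paddle
from fastdeploy.cache_manager.ops import share_external_data_  # wrapper shown in ops.py below

key_cache = paddle.zeros([1024, 8, 64, 128], dtype="bfloat16")
# Map an existing shared region onto this tensor; use_ipc=True takes the xpu_ipc_open_memhandle path.
key_cache = share_external_data_(key_cache, "key_caches_0_rank0", key_cache.shape, True)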

fastdeploy/cache_manager/cache_messager.py

Lines changed: 22 additions & 9 deletions
@@ -25,14 +25,20 @@
 import numpy as np
 import paddle
 
+from fastdeploy.cache_manager.ops import (
+    get_output_kv_signal,
+    get_peer_mem_addr,
+    memory_allocated,
+    set_data_ipc,
+    set_device,
+)
 from fastdeploy.cache_manager.transfer_factory import IPCCommManager, RDMACommManager
 from fastdeploy.config import SpeculativeConfig
 from fastdeploy.inter_communicator import (
     EngineWorkerQueue,
     IPCSignal,
     shared_memory_exists,
 )
-from fastdeploy.model_executor.ops.gpu import get_output_kv_signal, set_data_ipc
 from fastdeploy.utils import envs, get_logger
 
 logger = get_logger("cache_messager", "cache_messager.log")
@@ -157,16 +163,20 @@ def __init__(
             val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
             cache_k.append(key_cache)
             cache_v.append(val_cache)
-            cache_k_ptr_list.append(key_cache.data_ptr())
-            cache_v_ptr_list.append(val_cache.data_ptr())
+            if paddle.is_compiled_with_xpu():
+                cache_k_ptr_list.append(get_peer_mem_addr(key_cache.data_ptr()))
+                cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr()))
+            else:
+                cache_k_ptr_list.append(key_cache.data_ptr())
+                cache_v_ptr_list.append(val_cache.data_ptr())
         cache_k_ptr_list = np.array(cache_k_ptr_list)
         cache_v_ptr_list = np.array(cache_v_ptr_list)
 
         # 2. initialize the block_bytes
         cache_shape = key_cache.shape
         max_block_num = cache_shape[0]
         block_bytes = math.prod(cache_shape[1:])
-        if key_cache.dtype == paddle.bfloat16:
+        if key_cache.dtype == paddle.bfloat16 or key_cache.dtype == paddle.float16:
             block_bytes *= 2
         logger.info(
             f"layers {num_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
@@ -452,8 +462,12 @@ def __init__(
             val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
             cache_k.append(key_cache)
             cache_v.append(val_cache)
-            cache_k_ptr_list.append(key_cache.data_ptr())
-            cache_v_ptr_list.append(val_cache.data_ptr())
+            if paddle.is_compiled_with_xpu():
+                cache_k_ptr_list.append(get_peer_mem_addr(key_cache.data_ptr()))
+                cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr()))
+            else:
+                cache_k_ptr_list.append(key_cache.data_ptr())
+                cache_v_ptr_list.append(val_cache.data_ptr())
         cache_k_ptr_list = np.array(cache_k_ptr_list)
         cache_v_ptr_list = np.array(cache_v_ptr_list)
 
@@ -763,7 +777,7 @@ def _handle_connect_task(self):
 def main():
     device = args.device_id
     rank = args.rank
-    paddle.set_device(f"gpu:{device}")
+    set_device(device)
     cache_type = args.cache_dtype
     speculative_config = SpeculativeConfig(args.speculative_config)
     num_extra_layers = speculative_config.num_extra_cache_layer
@@ -823,7 +837,7 @@ def main():
     cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in gpu_cache_kvs.items()])
     logger.info(f"device :{device}")
     logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
-    logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}")
+    logger.info(f"done init cache (full) gmem alloc : {memory_allocated}")
 
     if envs.ENABLE_V1_KVCACHE_SCHEDULER:
         cache_messager = CacheMessagerV1(
@@ -875,7 +889,6 @@ def main():
     args = parse_args()
     rank_id = args.rank + args.local_data_parallel_id * args.mp_num
     logger = get_logger("cache_messager", f"cache_messager_rank{rank_id}.log")
-
     logger.info("create cache messager...")
     logger.info(f"{args}")
     main()
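
Two behavioural points in this file are worth spelling out: on XPU the messager now advertises peer-accessible addresses via get_peer_mem_addr rather than raw device pointers, and float16 caches are sized like bfloat16 at two bytes per element. A small arithmetic sketch of the block_bytes computation, with an invented cache shape:

# Illustrative numbers only; cache_shape is not taken from the commit.
import math

cache_shape = [1024, 8, 64, 128]            # [max_block_num, kv_heads, block_size, head_dim]
block_bytes = math.prod(cache_shape[1:])    # 8 * 64 * 128 = 65536 elements per block
dtype = "float16"
if dtype in ("bfloat16", "float16"):        # 2-byte element types
    block_bytes *= 2                        # 131072 bytes moved per KV block
print(block_bytes)                          # 131072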

fastdeploy/cache_manager/ops.py

Lines changed: 32 additions & 0 deletions
@@ -6,17 +6,27 @@
     from fastdeploy.model_executor.ops.gpu import (
         cuda_host_alloc,
         cuda_host_free,
+        get_data_ptr_ipc,
+        get_output_kv_signal,
+        ipc_sent_key_value_cache_by_remote_ptr,
+        ipc_sent_key_value_cache_by_remote_ptr_block_sync,
         set_data_ipc,
         share_external_data,
         swap_cache_all_layers,
         unset_data_ipc,
     )
 
     memory_allocated = paddle.device.cuda.memory_allocated
+
+    def get_peer_mem_addr(*args, **kwargs):
+        raise RuntimeError("CUDA no need of get_peer_mem_addr!")
+
 elif current_platform.is_xpu():
     from fastdeploy.model_executor.ops.xpu import (
         cuda_host_alloc,
         cuda_host_free,
+        get_output_kv_signal,
+        get_peer_mem_addr,
         set_data_ipc,
         share_external_data,
         swap_cache_all_layers,
@@ -25,6 +35,15 @@
     unset_data_ipc = None
     memory_allocated = paddle.device.xpu.memory_allocated
 
+    def get_data_ptr_ipc(*args, **kwargs):
+        raise RuntimeError("XPU get_data_ptr_ipc UNIMPLEMENTED!")
+
+    def ipc_sent_key_value_cache_by_remote_ptr(*args, **kwargs):
+        raise RuntimeError("XPU ipc_sent_key_value_cache_by_remote_ptr UNIMPLEMENTED")
+
+    def ipc_sent_key_value_cache_by_remote_ptr_block_sync(*args, **kwargs):
+        raise RuntimeError("XPU ipc_sent_key_value_cache_by_remote_ptr_block_sync UNIMPLEMENTED")
+
 else:
     raise RuntimeError("Prefix cache ops only supported CUDA nor XPU platform ")
 
@@ -48,6 +67,13 @@ def share_external_data_(cache, cache_name, cache_shape, use_ipc):
     return cache
 
 
+def get_all_visible_devices():
+    if current_platform.is_xpu():
+        return "XPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
+    else:
+        return "CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
+
+
 __all__ = [
     "cuda_host_alloc",
     "cuda_host_free",
@@ -57,4 +83,10 @@ def share_external_data_(cache, cache_name, cache_shape, use_ipc):
     "unset_data_ipc",  # None on XPU
     "set_device",
     "memory_allocated",
+    "get_output_kv_signal",
+    "get_data_ptr_ipc",
+    "ipc_sent_key_value_cache_by_remote_ptr",
+    "ipc_sent_key_value_cache_by_remote_ptr_block_sync",
+    "get_peer_mem_addr",
+    "get_all_visible_devices",
 ]
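
The structure of this module is a single dispatch point: each backend imports the ops it actually has and defines loud stubs for the ones it lacks, so callers always import from fastdeploy.cache_manager.ops and fail at call time rather than import time. A stripped-down sketch of the same pattern with generic names (not the repository code):

# Generic dispatch-with-stubs sketch; backend detection and names are placeholders.
BACKEND = "xpu"  # stand-in for the current_platform checks

if BACKEND == "cuda":
    def get_peer_mem_addr(*args, **kwargs):
        raise RuntimeError("get_peer_mem_addr is not needed on CUDA")
elif BACKEND == "xpu":
    def get_peer_mem_addr(ptr):
        # A real backend would translate ptr to a peer-accessible address here.
        return ptr
else:
    raise RuntimeError("unsupported backend")

__all__ = ["get_peer_mem_addr"]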

fastdeploy/cache_manager/prefix_cache_manager.py

Lines changed: 7 additions & 2 deletions
@@ -33,6 +33,7 @@
 from fastdeploy import envs
 from fastdeploy.cache_manager.cache_data import BlockNode, CacheStatus
 from fastdeploy.cache_manager.cache_metrics import CacheMetrics
+from fastdeploy.cache_manager.ops import get_all_visible_devices
 from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, PrefixTreeStatus
 from fastdeploy.metrics.metrics import main_process_metrics
 from fastdeploy.utils import get_logger
@@ -243,9 +244,11 @@ def launch_cache_manager(
         # Run command to launch cache transfer managers
         log_dir = envs.FD_LOG_DIR
         cache_manager_processes = []
+        visible_devices = get_all_visible_devices()
         for i in range(tensor_parallel_size):
             launch_cmd = (
-                "FLAGS_allocator_strategy=auto_growth CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
+                "FLAGS_allocator_strategy=auto_growth "
+                + visible_devices
                 + " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
                 + f" FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}"
                 + f" {sys.executable} {py_path}"
@@ -328,9 +331,11 @@ def launch_cache_messager(
         py_path = os.path.join(current_dir_path, filename)
         log_dir = envs.FD_LOG_DIR
         cache_messager_processes = []
+        visible_devices = get_all_visible_devices()
         for i in range(tensor_parallel_size):
             launch_cmd = (
-                "FLAGS_allocator_strategy=auto_growth CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
+                "FLAGS_allocator_strategy=auto_growth "
+                + visible_devices
                 + " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
                 + f" {sys.executable} {py_path}"
                 + f" --device_id {int(device_ids[i])}"

fastdeploy/cache_manager/transfer_factory/ipc_cache_transfer.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
 
 import paddle
 
-from fastdeploy.model_executor.ops.gpu import (
+from fastdeploy.cache_manager.ops import (
     get_data_ptr_ipc,
     ipc_sent_key_value_cache_by_remote_ptr,
     ipc_sent_key_value_cache_by_remote_ptr_block_sync,
