Skip to content

Commit beb151b

Browse files
author
ddchenhao66
committed
[XPU] xpu support PD disaggregation
1 parent 6c5ab72 commit beb151b

File tree

13 files changed: +236 −75 lines changed
Lines changed: 54 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,54 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "ops/remote_cache_kv_ipc.h"
16+
#include "paddle/extension.h"
17+
18+
#ifndef PD_BUILD_STATIC_OP
19+
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
20+
#endif
21+
22+
using cache_write_complete_signal_type =
23+
RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data;
24+
25+
paddle::Tensor InitSignalLayerwiseFunc(const paddle::Tensor& kv_signal_metadata,
26+
const int layer_id) {
27+
auto kv_signal_metadata_out =
28+
kv_signal_metadata.copy_to(paddle::CPUPlace(), false);
29+
kv_signal_metadata_out.data<int64_t>()[0] = static_cast<int64_t>(layer_id);
30+
return kv_signal_metadata_out;
31+
}
32+
33+
std::vector<paddle::Tensor> InitSignalLayerwise(
34+
const paddle::Tensor& kv_signal_metadata, const int layer_id) {
35+
return {InitSignalLayerwiseFunc(kv_signal_metadata, layer_id)};
36+
}
37+
38+
std::vector<std::vector<int64_t>> InitSignalLayerwiseShape(
39+
const std::vector<int64_t>& kv_signal_metadata_shape, const int layer_id) {
40+
return {kv_signal_metadata_shape};
41+
}
42+
43+
std::vector<paddle::DataType> InitSignalLayerwiseDtype(
44+
const paddle::DataType& kv_signal_metadata_dtype, const int layer_id) {
45+
return {paddle::DataType::INT64};
46+
}
47+
48+
PD_BUILD_STATIC_OP(init_signal_layerwise)
49+
.Inputs({"kv_signal_metadata"})
50+
.Outputs({"kv_signal_metadata_out"})
51+
.Attrs({"layer_id: int"})
52+
.SetKernelFn(PD_KERNEL(InitSignalLayerwise))
53+
.SetInferShapeFn(PD_INFER_SHAPE(InitSignalLayerwiseShape))
54+
.SetInferDtypeFn(PD_INFER_DTYPE(InitSignalLayerwiseDtype));

custom_ops/xpu_ops/src/ops/open_shm_and_get_meta_signal.cc

Lines changed: 16 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -17,22 +17,27 @@
1717
#include "ops/utility/env.h"
1818
#include "paddle/extension.h"
1919

20+
#ifndef PD_BUILD_STATIC_OP
21+
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
22+
#endif
23+
2024
XPU_DECLARE_BOOL(fmt_write_cache_completed_signal, false);
2125

2226
using cache_write_complete_signal_type =
2327
RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data;
2428

2529
paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank,
30+
const int device_id,
2631
const bool keep_pd_step_flag) {
2732
cache_write_complete_signal_type kv_signal_metadata;
28-
const char *fmt_write_cache_completed_signal_str =
33+
const char* fmt_write_cache_completed_signal_str =
2934
std::getenv("FLAGS_fmt_write_cache_completed_signal");
3035
if (fmt_write_cache_completed_signal_str &&
3136
(std::strcmp(fmt_write_cache_completed_signal_str, "true") == 0 ||
3237
std::strcmp(fmt_write_cache_completed_signal_str, "1") == 0)) {
3338
kv_signal_metadata =
3439
RemoteCacheKvIpc::open_shm_and_get_complete_signal_meta_data(
35-
rank, keep_pd_step_flag);
40+
rank, device_id, keep_pd_step_flag);
3641
}
3742

3843
auto kv_signal_metadata_out =
@@ -46,9 +51,9 @@ paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank,
4651
return kv_signal_metadata_out;
4752
}
4853

49-
void InitKVSignalPerQuery(const paddle::Tensor &seq_lens_encoder_tensor,
50-
const paddle::Tensor &seq_lens_this_time_tensor,
51-
const paddle::Tensor &seq_lens_decoder_tensor,
54+
void InitKVSignalPerQuery(const paddle::Tensor& seq_lens_encoder_tensor,
55+
const paddle::Tensor& seq_lens_this_time_tensor,
56+
const paddle::Tensor& seq_lens_decoder_tensor,
5257
const int rank,
5358
const int num_layers) {
5459
if (FLAGS_fmt_write_cache_completed_signal) {
@@ -68,24 +73,24 @@ void InitKVSignalPerQuery(const paddle::Tensor &seq_lens_encoder_tensor,
6873
}
6974

7075
std::vector<paddle::Tensor> OpenShmAndGetMetaSignal(
71-
const int rank, const bool keep_pd_step_flag) {
72-
return {OpenShmAndGetMetaSignalFunc(rank, keep_pd_step_flag)};
76+
const int rank, const int device_id, const bool keep_pd_step_flag) {
77+
return {OpenShmAndGetMetaSignalFunc(rank, device_id, keep_pd_step_flag)};
7378
}
7479

7580
std::vector<std::vector<int64_t>> OpenShmAndGetMetaSignalShape(
76-
const int rank, const bool keep_pd_step_flag) {
81+
const int rank, const int device_id, const bool keep_pd_step_flag) {
7782
return {{3}};
7883
}
7984

8085
std::vector<paddle::DataType> OpenShmAndGetMetaSignalDtype(
81-
const int rank, const bool keep_pd_step_flag) {
86+
const int rank, const int device_id, const bool keep_pd_step_flag) {
8287
return {paddle::DataType::INT64};
8388
}
8489

85-
PD_BUILD_OP(open_shm_and_get_meta_signal)
90+
PD_BUILD_STATIC_OP(open_shm_and_get_meta_signal)
8691
.Inputs({})
8792
.Outputs({"kv_signal_metadata"})
88-
.Attrs({"rank: int", "keep_pd_step_flag: bool"})
93+
.Attrs({"rank: int", "device_id: int", "keep_pd_step_flag: bool"})
8994
.SetKernelFn(PD_KERNEL(OpenShmAndGetMetaSignal))
9095
.SetInferShapeFn(PD_INFER_SHAPE(OpenShmAndGetMetaSignalShape))
9196
.SetInferDtypeFn(PD_INFER_DTYPE(OpenShmAndGetMetaSignalDtype));

custom_ops/xpu_ops/src/ops/remote_cache_kv_ipc.cc

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@ bool RemoteCacheKvIpc::kv_complete_signal_shmem_opened = false;
2626

2727
RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data
2828
RemoteCacheKvIpc::open_shm_and_get_complete_signal_meta_data(
29-
const int rank_id, const bool keep_pd_step_flag) {
29+
const int rank_id, const int device_id, const bool keep_pd_step_flag) {
3030
if (RemoteCacheKvIpc::kv_complete_signal_shmem_opened) {
3131
if (keep_pd_step_flag) {
3232
return RemoteCacheKvIpc::kv_complete_signal_meta_data;
@@ -47,12 +47,13 @@ RemoteCacheKvIpc::open_shm_and_get_complete_signal_meta_data(
4747
std::string iflags_server_uuid_env_str(iflags_server_uuid_env_p);
4848
flags_server_uuid = iflags_server_uuid_env_str;
4949
}
50+
5051
std::string step_shm_name =
51-
("splitwise_complete_prefilled_step_" + std::to_string(rank_id) + "_" +
52-
flags_server_uuid);
52+
("splitwise_complete_prefilled_step_" + std::to_string(rank_id) + "." +
53+
std::to_string(device_id));
5354
std::string layer_shm_name =
54-
("splitwise_complete_prefilled_layer_" + std::to_string(rank_id) + "_" +
55-
flags_server_uuid);
55+
("splitwise_complete_prefilled_layer_" + std::to_string(rank_id) + "." +
56+
std::to_string(device_id));
5657
if (const char* use_ep = std::getenv("ENABLE_EP_DP")) {
5758
if (std::strcmp(use_ep, "1") == 0) {
5859
step_shm_name = "splitwise_complete_prefilled_step_tprank0_dprank" +

custom_ops/xpu_ops/src/ops/remote_cache_kv_ipc.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -91,6 +91,7 @@ struct RemoteCacheKvIpc {
9191

9292
static RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data
9393
open_shm_and_get_complete_signal_meta_data(const int rank_id,
94+
const int device_id,
9495
const bool keep_pd_step_flag);
9596
static void save_cache_kv_complete_signal_layerwise(void* meta_data);
9697
static void save_cache_kv_complete_signal_layerwise_per_query(

custom_ops/xpu_ops/src/ops/share_external_data.cc

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -19,26 +19,26 @@
1919
#include "xpu/plugin.h"
2020
#include "xpu_multiprocess.h" // NOLINT(build/include_subdir)
2121

22-
std::vector<paddle::Tensor> ShareExternalData(const paddle::Tensor &input,
22+
std::vector<paddle::Tensor> ShareExternalData(const paddle::Tensor& input,
2323
const std::string shm_name,
24-
const std::vector<int> &shape,
24+
const std::vector<int>& shape,
2525
bool use_ipc) {
2626
sharedMemoryInfo info;
2727
int ret = sharedMemoryOpen(shm_name.c_str(), sizeof(shmStruct), &info);
2828
PD_CHECK(ret == 0, "sharedMemoryOpen failed");
29-
volatile shmStruct *shm = static_cast<volatile shmStruct *>(info.addr);
30-
void *data_ptr_addr = nullptr;
29+
volatile shmStruct* shm = static_cast<volatile shmStruct*>(info.addr);
30+
void* data_ptr_addr = nullptr;
3131
if (use_ipc) {
3232
#if XPURT_VERSION_MAJOR == 5
3333
int ret = xpu_ipc_open_memhandle(&data_ptr_addr,
34-
*(XPUIpcMemHandle *)&shm->memHandle,
34+
*(XPUIpcMemHandle*)&shm->memHandle,
3535
0x01); // NOLINT
36-
PD_CHECK(ret == XPU_SUCCESS, "xpu_ipc_open_memhandle failed");
36+
PD_CHECK(ret == XPU_SUCCESS, "%s xpu_ipc_open_memhandle failed", shm_name);
3737
#elif XPURT_VERSION_MAJOR == 4
3838
PD_THROW("kl2 not support prefix cache");
3939
#endif
4040
} else {
41-
data_ptr_addr = reinterpret_cast<void *>(shm->data_ptr_addr);
41+
data_ptr_addr = reinterpret_cast<void*>(shm->data_ptr_addr);
4242
}
4343

4444
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());

fastdeploy/cache_manager/cache_messager.py

Lines changed: 19 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -25,14 +25,20 @@
2525
import numpy as np
2626
import paddle
2727

28+
from fastdeploy.cache_manager.ops import (
29+
get_output_kv_signal,
30+
get_peer_mem_addr,
31+
memory_allocated,
32+
set_data_ipc,
33+
set_device,
34+
)
2835
from fastdeploy.cache_manager.transfer_factory import IPCCommManager, RDMACommManager
2936
from fastdeploy.config import SpeculativeConfig
3037
from fastdeploy.inter_communicator import (
3138
EngineWorkerQueue,
3239
IPCSignal,
3340
shared_memory_exists,
3441
)
35-
from fastdeploy.model_executor.ops.gpu import get_output_kv_signal, set_data_ipc
3642
from fastdeploy.utils import envs, get_logger
3743

3844
logger = get_logger("cache_messager", "cache_messager.log")
@@ -155,18 +161,25 @@ def __init__(
155161
for layer_idx in range(self.num_layers):
156162
key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
157163
val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
164+
logger.info(
165+
f"[key_cache: {hex(key_cache.data_ptr())}],[key_cache_mem: {hex(get_peer_mem_addr(key_cache.data_ptr()))}]"
166+
)
158167
cache_k.append(key_cache)
159168
cache_v.append(val_cache)
160-
cache_k_ptr_list.append(key_cache.data_ptr())
161-
cache_v_ptr_list.append(val_cache.data_ptr())
169+
if paddle.is_compiled_with_xpu():
170+
cache_k_ptr_list.append(get_peer_mem_addr(key_cache.data_ptr()))
171+
cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr()))
172+
else:
173+
cache_k_ptr_list.append(key_cache.data_ptr())
174+
cache_v_ptr_list.append(val_cache.data_ptr())
162175
cache_k_ptr_list = np.array(cache_k_ptr_list)
163176
cache_v_ptr_list = np.array(cache_v_ptr_list)
164177

165178
# 2. initialize the block_bytes
166179
cache_shape = key_cache.shape
167180
max_block_num = cache_shape[0]
168181
block_bytes = math.prod(cache_shape[1:])
169-
if key_cache.dtype == paddle.bfloat16:
182+
if key_cache.dtype == paddle.bfloat16 or key_cache.dtype == paddle.float16:
170183
block_bytes *= 2
171184
logger.info(
172185
f"layers {num_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
@@ -758,7 +771,7 @@ def _handle_connect_task(self):
758771
def main():
759772
device = args.device_id
760773
rank = args.rank
761-
paddle.set_device(f"gpu:{device}")
774+
set_device(args.rank)
762775
cache_type = args.cache_dtype
763776
speculative_config = SpeculativeConfig(args.speculative_config)
764777
num_extra_layers = speculative_config.num_extra_cache_layer
@@ -818,7 +831,7 @@ def main():
818831
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in gpu_cache_kvs.items()])
819832
logger.info(f"device :{device}")
820833
logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
821-
logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}")
834+
logger.info(f"done init cache (full) gmem alloc : {memory_allocated}")
822835

823836
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
824837
cache_messager = CacheMessagerV1(

fastdeploy/cache_manager/ops.py

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,17 +6,27 @@
66
from fastdeploy.model_executor.ops.gpu import (
77
cuda_host_alloc,
88
cuda_host_free,
9+
get_data_ptr_ipc,
10+
get_output_kv_signal,
11+
ipc_sent_key_value_cache_by_remote_ptr,
12+
ipc_sent_key_value_cache_by_remote_ptr_block_sync,
913
set_data_ipc,
1014
share_external_data,
1115
swap_cache_all_layers,
1216
unset_data_ipc,
1317
)
1418

1519
memory_allocated = paddle.device.cuda.memory_allocated
20+
21+
def get_peer_mem_addr(*args, **kwargs):
22+
raise RuntimeError("CUDA no need of get_peer_mem_addr!")
23+
1624
elif current_platform.is_xpu():
1725
from fastdeploy.model_executor.ops.xpu import (
1826
cuda_host_alloc,
1927
cuda_host_free,
28+
get_output_kv_signal,
29+
get_peer_mem_addr,
2030
set_data_ipc,
2131
share_external_data,
2232
swap_cache_all_layers,
@@ -25,6 +35,15 @@
2535
unset_data_ipc = None
2636
memory_allocated = paddle.device.xpu.memory_allocated
2737

38+
def get_data_ptr_ipc(*args, **kwargs):
39+
raise RuntimeError("XPU get_data_ptr_ipc UNIMPLENENTED!")
40+
41+
def ipc_sent_key_value_cache_by_remote_ptr(*args, **kwargs):
42+
raise RuntimeError("XPU ipc_sent_key_value_cache_by_remote_ptr UNIMPLENENTED")
43+
44+
def ipc_sent_key_value_cache_by_remote_ptr_block_sync(*args, **kwargs):
45+
raise RuntimeError("XPU No ipc_sent_key_value_cache_by_remote_ptr UNIMPLENENTED")
46+
2847
else:
2948
raise RuntimeError("Prefix cache ops only supported CUDA nor XPU platform ")
3049

@@ -57,4 +76,9 @@ def share_external_data_(cache, cache_name, cache_shape, use_ipc):
5776
"unset_data_ipc", # XPU是 None
5877
"set_device",
5978
"memory_allocated",
79+
"get_output_kv_signal",
80+
"get_data_ptr_ipc",
81+
"ipc_sent_key_value_cache_by_remote_ptr",
82+
"ipc_sent_key_value_cache_by_remote_ptr_block_sync",
83+
"get_peer_mem_addr",
6084
]

fastdeploy/cache_manager/transfer_factory/ipc_cache_transfer.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616

1717
import paddle
1818

19-
from fastdeploy.model_executor.ops.gpu import (
19+
from fastdeploy.cache_manager.ops import (
2020
get_data_ptr_ipc,
2121
ipc_sent_key_value_cache_by_remote_ptr,
2222
ipc_sent_key_value_cache_by_remote_ptr_block_sync,

fastdeploy/model_executor/forward_meta.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -244,6 +244,8 @@ class XPUForwardMeta(ForwardMeta):
244244
total_enc_len: Optional[paddle.Tensor] = None
245245
# position embedding type in rope, supports 'NORMAL' or 'HALF_HEAD_DIM'
246246
pos_emb_type: Optional[str] = "NORMAL"
247+
# for pd_disaggregation
248+
kv_signal_sender: Optional[paddle.Tensor] = None
247249

248250

249251
@dataclass

fastdeploy/model_executor/layers/attention/ops/init_signal_layerwise.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,11 @@ def init_signal_layerwise(
2929
if current_platform.is_cuda():
3030
from fastdeploy.model_executor.ops.gpu import init_signal_layerwise
3131

32+
out = init_signal_layerwise(kv_signal_metadata, layer_id)
33+
return out
34+
elif current_platform.is_xpu():
35+
from fastdeploy.model_executor.ops.xpu import init_signal_layerwise
36+
3237
out = init_signal_layerwise(kv_signal_metadata, layer_id)
3338
return out
3439
else:

0 commit comments

Comments (0)