1717#include " ops/utility/env.h"
1818#include " paddle/extension.h"
1919
// Fallback for builds where PD_BUILD_STATIC_OP has not been defined yet:
// register the op through PD_BUILD_OP under a `static_op_`-prefixed name.
// NOTE: there must be no whitespace between the macro name and `(`,
// otherwise the preprocessor treats this as an object-like macro.
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
23+
2024XPU_DECLARE_BOOL (fmt_write_cache_completed_signal, false );
2125
2226using cache_write_complete_signal_type =
2327 RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data;
2428
2529paddle::Tensor OpenShmAndGetMetaSignalFunc (const int rank,
30+ const int device_id,
2631 const bool keep_pd_step_flag) {
2732 cache_write_complete_signal_type kv_signal_metadata;
28- const char * fmt_write_cache_completed_signal_str =
33+ const char * fmt_write_cache_completed_signal_str =
2934 std::getenv (" FLAGS_fmt_write_cache_completed_signal" );
3035 if (fmt_write_cache_completed_signal_str &&
3136 (std::strcmp (fmt_write_cache_completed_signal_str, " true" ) == 0 ||
3237 std::strcmp (fmt_write_cache_completed_signal_str, " 1" ) == 0 )) {
3338 kv_signal_metadata =
3439 RemoteCacheKvIpc::open_shm_and_get_complete_signal_meta_data (
35- rank, keep_pd_step_flag);
40+ rank, device_id, keep_pd_step_flag);
3641 }
3742
3843 auto kv_signal_metadata_out =
@@ -46,9 +51,9 @@ paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank,
4651 return kv_signal_metadata_out;
4752}
4853
49- void InitKVSignalPerQuery (const paddle::Tensor & seq_lens_encoder_tensor,
50- const paddle::Tensor & seq_lens_this_time_tensor,
51- const paddle::Tensor & seq_lens_decoder_tensor,
54+ void InitKVSignalPerQuery (const paddle::Tensor& seq_lens_encoder_tensor,
55+ const paddle::Tensor& seq_lens_this_time_tensor,
56+ const paddle::Tensor& seq_lens_decoder_tensor,
5257 const int rank,
5358 const int num_layers) {
5459 if (FLAGS_fmt_write_cache_completed_signal) {
@@ -68,24 +73,24 @@ void InitKVSignalPerQuery(const paddle::Tensor &seq_lens_encoder_tensor,
6873}
6974
7075std::vector<paddle::Tensor> OpenShmAndGetMetaSignal (
71- const int rank, const bool keep_pd_step_flag) {
72- return {OpenShmAndGetMetaSignalFunc (rank, keep_pd_step_flag)};
76+ const int rank, const int device_id, const bool keep_pd_step_flag) {
77+ return {OpenShmAndGetMetaSignalFunc (rank, device_id, keep_pd_step_flag)};
7378}
7479
// Shape-inference function for the `open_shm_and_get_meta_signal` op.
//
// The parameters mirror the op's attributes but are not used: the function
// always reports a single output whose shape is the 1-D shape [3].
//
// @return one shape entry, {3}, for the op's single output.
std::vector<std::vector<int64_t>> OpenShmAndGetMetaSignalShape(
    const int rank, const int device_id, const bool keep_pd_step_flag) {
  return {{3}};
}
7984
8085std::vector<paddle::DataType> OpenShmAndGetMetaSignalDtype (
81- const int rank, const bool keep_pd_step_flag) {
86+ const int rank, const int device_id, const bool keep_pd_step_flag) {
8287 return {paddle::DataType::INT64};
8388}
8489
85- PD_BUILD_OP (open_shm_and_get_meta_signal)
90+ PD_BUILD_STATIC_OP (open_shm_and_get_meta_signal)
8691 .Inputs({})
8792 .Outputs({" kv_signal_metadata" })
88- .Attrs({" rank: int" , " keep_pd_step_flag: bool" })
93+ .Attrs({" rank: int" , " device_id: int " , " keep_pd_step_flag: bool" })
8994 .SetKernelFn(PD_KERNEL(OpenShmAndGetMetaSignal))
9095 .SetInferShapeFn(PD_INFER_SHAPE(OpenShmAndGetMetaSignalShape))
9196 .SetInferDtypeFn(PD_INFER_DTYPE(OpenShmAndGetMetaSignalDtype));
0 commit comments