
Commit dda027e

[KVPool] Support PP (#4761)

### What this PR does / why we need it?

Support pipeline parallelism (PP) for the KV pool.

- vLLM version: v0.12.0
- vLLM main: vllm-project/vllm@ad32e3e

Signed-off-by: baxingpiaochong <[email protected]>
1 parent: 9038865

3 files changed (+16, -3 lines)

vllm_ascend/distributed/kvpool/config_data.py

Lines changed: 5 additions & 2 deletions
@@ -21,6 +21,8 @@ class KeyMetadata:
     pcp_rank: int
     """ Initialize the current decode context model parallel rank """
     dcp_rank: int
+    """ Initialize the current pipeline parallel rank """
+    pp_rank: int
 
 
 @dataclass(order=True)
@@ -34,15 +36,16 @@ def __hash__(self):
             self.key_metadata.head_or_tp_rank,
             self.key_metadata.pcp_rank,
             self.key_metadata.dcp_rank,
+            self.key_metadata.pp_rank,
             self.chunk_hash,
         ))
 
     def to_string(self):
         return (
             f"{self.key_metadata.model_name}"
             f"@pcp{self.key_metadata.pcp_rank}@dcp{self.key_metadata.dcp_rank}"
-            f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}@{self.chunk_hash}"
-        )
+            f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}"
+            f"@pp_rank:{self.key_metadata.pp_rank}@{self.chunk_hash}")
 
     def split_layers(self, num_layers: int) -> List["LayerPoolKey"]:
         """Split the key into multiple keys for each layer"""

vllm_ascend/distributed/kvpool/pool_worker.py

Lines changed: 9 additions & 0 deletions
@@ -48,6 +48,8 @@ def __init__(
         self.use_layerwise = use_layerwize
         self.tp_rank = get_tensor_model_parallel_rank()
         self.tp_size = get_tensor_model_parallel_world_size()
+        self.pp_size = parallel_config.pipeline_parallel_size
+        self.pp_rank = (parallel_config.rank // self.tp_size) % self.pp_size
 
         self.pcp_size = get_pcp_group().world_size
         self.pcp_rank = get_pcp_group(
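
The pp_rank derivation above assumes the global rank is laid out with TP as the fastest-varying dimension; the helper below is our illustration of the same arithmetic, not code from the repository:

```python
def derive_pp_rank(global_rank: int, tp_size: int, pp_size: int) -> int:
    # Same arithmetic as the diff: strip the TP dimension off the global
    # rank, then wrap by pp_size in case outer dimensions (e.g. data
    # parallelism) also contribute to the global rank.
    return (global_rank // tp_size) % pp_size


# With tp_size=2 and pp_size=2, global ranks 0..3 map to PP stages 0, 0, 1, 1.
assert [derive_pp_rank(r, tp_size=2, pp_size=2) for r in range(4)] == [0, 0, 1, 1]
```
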
@@ -87,6 +89,7 @@ def __init__(
             self.head_or_tp_rank,
             self.pcp_rank,
             self.dcp_rank,
+            self.pp_rank,
         )
 
         self.token_database = ChunkedTokenDatabase(self.metadata,
@@ -555,6 +558,12 @@ def lookup_scheduler(
                     "@head_or_tp_rank:0", f"@head_or_tp_rank:{i}", 1)
                 multi_tp_keys.append(new_str)
 
+        for i in range(1, self.pp_size):
+            for item in keys:
+                new_str = item.replace(  # type: ignore[attr-defined]
+                    "@pp_rank:0", f"@pp_rank:{i}", 1)
+                multi_tp_keys.append(new_str)
+
         res = self.m_store.exists(
             multi_tp_keys)  # type: ignore[assignment]
         num_block = len(keys)
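
The new loop mirrors the TP expansion just above it: scheduler-side keys carry @pp_rank:0, and the lookup rewrites that token so the store is probed for every pipeline stage. A minimal sketch with a hypothetical key string:

```python
keys = ["some-model@pcp0@dcp0@head_or_tp_rank:0@pp_rank:0@abc123"]
pp_size = 2

multi_keys = list(keys)  # the pp_rank:0 variants are already present
for i in range(1, pp_size):
    for item in keys:
        # Replace only the first occurrence, as the diff does.
        multi_keys.append(item.replace("@pp_rank:0", f"@pp_rank:{i}", 1))

print(multi_keys)
# ['some-model@pcp0@dcp0@head_or_tp_rank:0@pp_rank:0@abc123',
#  'some-model@pcp0@dcp0@head_or_tp_rank:0@pp_rank:1@abc123']
```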

vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 1 deletion
@@ -2450,6 +2450,7 @@ def execute_model(
                 attn_metadata, self.with_prefill, maybe_padded_num_tokens,
                 input_ids, positions, intermediate_tensors, inputs_embeds)
 
+        self.maybe_wait_for_kv_save()
         finished_sending, finished_recving = self.get_finished_kv_transfer(
             scheduler_output)
 

@@ -2711,7 +2712,7 @@ def propose_draft_token_ids(sampled_token_ids):
         # ngram and other speculative decoding methods use the sampled
         # tokens on the CPU, so they are run after bookkeeping.
         propose_draft_token_ids(valid_sampled_token_ids)
-        self.maybe_wait_for_kv_save()
+
         if has_kv_transfer_group():
             get_kv_transfer_group().clear_connector_metadata()
 
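
Reading the two model_runner_v1.py hunks together: maybe_wait_for_kv_save() moves from the post-sampling bookkeeping path to immediately after model execution, so pending KV-cache saves are awaited before finished transfers are collected. That ordering rationale is our inference from the diff; the PR description does not spell it out.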
