Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions vllm_ascend/distributed/kvpool/config_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ class KeyMetadata:
pcp_rank: int
""" Initialize the current decode context model parallel rank """
dcp_rank: int
""" Initialize the current pipeline parallel rank """
pp_rank: int


@dataclass(order=True)
Expand All @@ -34,15 +36,16 @@ def __hash__(self):
self.key_metadata.head_or_tp_rank,
self.key_metadata.pcp_rank,
self.key_metadata.dcp_rank,
self.key_metadata.pp_rank,
self.chunk_hash,
))

def to_string(self):
return (
f"{self.key_metadata.model_name}"
f"@pcp{self.key_metadata.pcp_rank}@dcp{self.key_metadata.dcp_rank}"
f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}@{self.chunk_hash}"
)
f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}"
f"@pp_rank:{self.key_metadata.pp_rank}@{self.chunk_hash}")

def split_layers(self, num_layers: int) -> List["LayerPoolKey"]:
"""Split the key into multiple keys for each layer"""
Expand Down
9 changes: 9 additions & 0 deletions vllm_ascend/distributed/kvpool/pool_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def __init__(
self.use_layerwise = use_layerwize
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.pp_size = parallel_config.pipeline_parallel_size
self.pp_rank = (parallel_config.rank // self.tp_size) % self.pp_size

self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group(
Expand Down Expand Up @@ -87,6 +89,7 @@ def __init__(
self.head_or_tp_rank,
self.pcp_rank,
self.dcp_rank,
self.pp_rank,
)

self.token_database = ChunkedTokenDatabase(self.metadata,
Expand Down Expand Up @@ -555,6 +558,12 @@ def lookup_scheduler(
"@head_or_tp_rank:0", f"@head_or_tp_rank:{i}", 1)
multi_tp_keys.append(new_str)

for i in range(1, self.pp_size):
for item in keys:
new_str = item.replace( # type: ignore[attr-defined]
"@pp_rank:0", f"@pp_rank:{i}", 1)
multi_tp_keys.append(new_str)
Comment on lines +561 to +565
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The current logic for checking key existence across both tensor-parallel (TP) and pipeline-parallel (PP) ranks is flawed.

  1. Incomplete key generation: It does not generate keys for all (TP, PP) rank combinations; it only checks (TP=i, PP=0) and (TP=0, PP=j). For example, with two TP ranks and two PP ranks, the key for (TP=1, PP=1) is never generated. This will result in missed cache hits when both TP and PP are greater than 1.
  2. Incorrect result processing: The subsequent result processing logic (lines 573-576) is not updated for pipeline parallelism and will likely fail with an IndexError or produce incorrect results.

The entire key-generation and result-processing block needs to be refactored. Additionally, the implementation relies on the hardcoded substrings `@head_or_tp_rank:0` and `@pp_rank:0`, which is brittle: the replacement is a no-op whenever the input keys were not built with rank 0. A more robust solution would replace the current worker's own rank in the key string.

Here is a corrected implementation for lines 554-579 that addresses the combination issue, assuming this lookup is always performed from a worker with tp_rank=0 and pp_rank=0:

            multi_tp_keys = []
            for pp_i in range(self.pp_size):
                for tp_i in range(min(self.tp_size, self.num_kv_head)):
                    for item in keys:
                        item_with_pp = item.replace("@pp_rank:0", f"@pp_rank:{pp_i}", 1)
                        new_str = item_with_pp.replace("@head_or_tp_rank:0", f"@head_or_tp_rank:{tp_i}", 1)
                        multi_tp_keys.append(new_str)

            res = self.m_store.exists(
                multi_tp_keys)  # type: ignore[assignment]
            num_block = len(keys)
            if use_layerwise:
                res = self.check_all_layers_exists(res, self.num_layers)
                num_block = len(keys) // self.num_layers

            num_ranks = self.pp_size * min(self.tp_size, self.num_kv_head)
            multi_rank_values = [
                res[i * num_block:(i + 1) * num_block]
                for i in range(num_ranks)
            ]
            index = self.find_min_first_non_one_index(multi_rank_values)
            if index != -1:
                return starts[index]


res = self.m_store.exists(
multi_tp_keys) # type: ignore[assignment]
num_block = len(keys)
Expand Down
Loading