Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions kvcached/integration/vllm/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,18 @@
def should_use_worker_ipc() -> bool:
    """Tell whether the caller should go through worker IPC.

    The result is truthy only when the module-level ``_kvcached_initialized``
    flag is truthy and the module-level ``_is_worker`` flag is falsy.
    """
    if _kvcached_initialized:
        return not _is_worker
    # Hand back the falsy flag itself, exactly as ``and`` short-circuiting would.
    return _kvcached_initialized


def get_world_size(default: int = 1) -> int:
    """Return the recorded tensor-parallel (TP) world size as a positive int.

    The module-level ``_world_size`` value is coerced with ``int()``. If the
    value cannot be coerced, or the coerced value is smaller than 1, a
    warning is logged and ``default`` is returned instead, because a world
    size below 1 cannot describe any set of ranks.

    Args:
        default: Value returned when the recorded world size is unusable.

    Returns:
        The recorded world size when it is a valid integer >= 1, otherwise
        ``default``.
    """
    try:
        recorded = int(_world_size)
    except (TypeError, ValueError, OverflowError):
        # None, non-numeric strings, NaN and infinities all land here.
        recorded = None

    if recorded is not None and recorded >= 1:
        return recorded

    logger.warning(
        "Invalid recorded TP world size %r; falling back to %d",
        _world_size,
        default,
    )
    return default

def init_kvcached(
tp_rank: int = 0,
world_size: int = 1,
Expand Down
13 changes: 6 additions & 7 deletions kvcached/integration/vllm/patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,15 +713,14 @@ def _setup_kvcached_coordinator(self) -> None:
cell_size, num_kv_buffers = _get_kv_cache_params(
kv_cache_spec, block_size, attention_type=attention_type)

try:
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size

tp_size = int(get_tensor_model_parallel_world_size())
except Exception:
tp_size = 1

from kvcached.integration.vllm import interfaces as kvi

# Reuse the TP world size recorded during EngineCore init.
# At coordinator construction time, vLLM's parallel_state helpers
# can still observe world_size=1 even though TP workers will be
# launched with the correct tensor_parallel_size later in startup.
tp_size = kvi.get_world_size()

# Use tp_size (not TP*PP global world size) for the KVCacheManager world_size.
# Each PP stage manages its own KV tensors independently. The IPC sockets
# are registered per TP rank within each stage (w0.sock … w(tp_size-1).sock).
Expand Down
103 changes: 103 additions & 0 deletions tests/test_tp_world_size_patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# SPDX-FileCopyrightText: Copyright contributors to the kvcached project
# SPDX-License-Identifier: Apache-2.0

import importlib
import sys
import types
from importlib.machinery import ModuleSpec
from unittest import mock

import pytest


def _load_patches(monkeypatch):
    """Import ``kvcached.integration.vllm.patches`` against stubbed dependencies.

    Every stub is registered in ``sys.modules`` through ``monkeypatch`` so the
    entries are undone when the test finishes. The stubbed
    ``vllm.distributed.parallel_state`` reports a TP world size of 1, while
    the stubbed kvcached ``interfaces`` module reports 2, which lets a test
    tell the two sources apart.

    Args:
        monkeypatch: The pytest ``monkeypatch`` fixture.

    Returns:
        A ``(patches, interfaces_stub)`` tuple: the freshly reloaded patches
        module and the stub ``interfaces`` module holding the mocks.
    """

    def _install(name, module):
        # Register ``module`` under ``name`` for the duration of the test.
        monkeypatch.setitem(sys.modules, name, module)

    def _stub_with_spec(name):
        # Bare module object carrying a loader-less spec.
        stub = types.ModuleType(name)
        stub.__spec__ = ModuleSpec(name, loader=None)
        return stub

    fake_torch = mock.MagicMock()
    fake_torch.__version__ = "2.6.0"
    mocked_dependencies = {
        "torch": fake_torch,
        "torch.cuda": fake_torch.cuda,
        "torch.utils": fake_torch.utils,
        "torch.utils.cpp_extension": fake_torch.utils.cpp_extension,
        "posix_ipc": mock.MagicMock(),
        "kvcached.vmm_ops": mock.MagicMock(),
    }
    for dep_name, dep_module in mocked_dependencies.items():
        _install(dep_name, dep_module)

    interfaces_stub = types.ModuleType("kvcached.integration.vllm.interfaces")
    interfaces_stub.get_world_size = mock.Mock(return_value=2)
    interfaces_stub.init_kvcached = mock.Mock()
    _install("kvcached.integration.vllm.interfaces", interfaces_stub)

    parallel_state_stub = _stub_with_spec("vllm.distributed.parallel_state")
    parallel_state_stub.get_tensor_model_parallel_world_size = lambda: 1
    _install("vllm.distributed.parallel_state", parallel_state_stub)

    for package_name in ("vllm", "vllm.distributed"):
        _install(package_name, _stub_with_spec(package_name))

    # Import, then reload, so the module body runs against the stubs above
    # even if it had already been imported earlier.
    loaded = importlib.import_module("kvcached.integration.vllm.patches")
    reloaded = importlib.reload(loaded)
    return reloaded, interfaces_stub


class FakeElasticBlockPool:
    """Minimal stand-in for ``ElasticBlockPool`` in the coordinator test.

    The constructor accepts and ignores any arguments; the only attribute it
    provides is ``null_block``.
    """

    def __init__(self, *_args, **_kwargs):
        # A fresh plain object per instance serves as the null-block sentinel.
        self.null_block = object()


def test_kv_cache_coordinator_reuses_enginecore_world_size(monkeypatch):
    """The patched coordinator must take its world size from ``interfaces``.

    The stubbed ``interfaces.get_world_size`` returns 2, whereas the stubbed
    vLLM ``parallel_state`` helper returns 1. The test therefore passes only
    if ``init_kvcached`` receives ``world_size=2``.
    """
    patches, interfaces_mod = _load_patches(monkeypatch)

    # Replace the helpers that the patches module looks up by name, so that
    # no real vLLM configuration objects are needed.
    helper_overrides = {
        "enable_kvcached": lambda: True,
        "_validate_kv_cache_groups": lambda cfg: None,
        "_get_first_attention_group": lambda cfg: types.SimpleNamespace(
            kv_cache_spec=types.SimpleNamespace(block_size=16)
        ),
        "_infer_attention_type": lambda cfg: "MHA",
        "_get_kv_cache_params": lambda *args, **kwargs: (1024, 2),
        "_get_group_size": lambda cfg: 1,
        "_get_max_cached_blocks": lambda block_size: 0,
        "_should_enable_async_sched": lambda cfg: False,
    }
    for helper_name, replacement in helper_overrides.items():
        monkeypatch.setattr(patches, helper_name, replacement)

    block_pool_stub = types.ModuleType("vllm.v1.core.block_pool")
    block_pool_stub.ElasticBlockPool = FakeElasticBlockPool
    monkeypatch.setitem(sys.modules, "vllm.v1.core.block_pool", block_pool_stub)

    class FakeKVCacheCoordinator:
        def __init__(self, *args, **kwargs):
            self.enable_caching = False
            self.kv_cache_config = types.SimpleNamespace(num_blocks=8)
            self.single_type_managers = [types.SimpleNamespace()]

    kvcoord_mod = types.ModuleType("mock_kvcoord_mod")
    kvcoord_mod.KVCacheCoordinator = FakeKVCacheCoordinator

    coordinator_patch = patches.KVCacheCoordinatorPatch()
    assert coordinator_patch.patch_coordinator(kvcoord_mod)

    # Instantiate through the module attribute, as in the original test.
    coordinator = kvcoord_mod.KVCacheCoordinator()

    get_world_size_mock = interfaces_mod.get_world_size
    init_kvcached_mock = interfaces_mod.init_kvcached

    get_world_size_mock.assert_called_once_with()
    init_kvcached_mock.assert_called_once()
    init_kwargs = init_kvcached_mock.call_args.kwargs
    assert init_kwargs["world_size"] == 2
    assert isinstance(coordinator.block_pool, FakeElasticBlockPool)