From 577041023c75edb2fabffe48be30b02dfd1d93db Mon Sep 17 00:00:00 2001 From: Guan Luo Date: Mon, 6 Oct 2025 19:30:56 -0700 Subject: [PATCH 01/19] fix: missing NIXL metadata for handshake initialization if instance spans multi-node Signed-off-by: Guan Luo --- .../kv_transfer/kv_connector/v1/base.py | 29 +++ .../kv_connector/v1/nixl_connector.py | 175 +++++++++++------- vllm/executor/uniproc_executor.py | 6 +- vllm/v1/engine/core.py | 22 +++ vllm/v1/executor/abstract.py | 4 + vllm/v1/worker/gpu_worker.py | 23 ++- 6 files changed, 187 insertions(+), 72 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 70225e95aed2..47ab3083b5d6 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -77,6 +77,14 @@ class KVConnectorRole(enum.Enum): WORKER = 1 +class KVConnectorHandshakeMetadata(ABC): # noqa: B024 + """ + Metadata used for out of band connector handshakeandshake between + P/D workers. This needs to serializeable. + """ + pass + + class KVConnectorMetadata(ABC): # noqa: B024 """ Abstract Metadata used to communicate between the @@ -271,6 +279,18 @@ def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]: """ return None + def get_handshake_metadata(self) -> Optional[KVConnectorHandshakeMetadata]: + """ + Get the KVConnector handshake metadata for this connector. + This metadata is used for out-of-band connector handshake + between P/D workers. + + Returns: + KVConnectorHandshakeMetadata: the handshake metadata. + None if no handshake metadata is available. + """ + return None + # ============================== # Scheduler-side methods # ============================== @@ -422,3 +442,12 @@ def build_kv_connector_stats( which can implement custom aggregation logic on the data dict. """ return None + + def set_xfer_handshake_metadata(self, metadata: dict[int, dict[int, dict]]) -> None: + """ + Set the KV connector handshake metadata for this connector. + + Args: + metadata (dict): the handshake metadata to set. + """ + pass diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index e3e3389fd164..0d2a2aa666d8 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -27,6 +27,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import ( CopyBlocksOp, KVConnectorBase_V1, + KVConnectorHandshakeMetadata, KVConnectorMetadata, KVConnectorRole, ) @@ -88,12 +89,7 @@ _NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices()) -class NixlAgentMetadata( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - # required for @cached_property. - dict=True, -): +class NixlAgentMetadata(KVConnectorHandshakeMetadata): engine_id: str agent_metadata: bytes kv_caches_base_addr: list[int] @@ -217,6 +213,16 @@ def request_finished( ) -> tuple[bool, Optional[dict[str, Any]]]: assert self.connector_scheduler is not None return self.connector_scheduler.request_finished(request, block_ids) + + def set_xfer_handshake_metadata(self, metadata: dict[int, dict[int, dict]]) -> None: + """ + Set the KV connector handshake metadata for this connector. + + Args: + metadata (dict): the handshake metadata to set. 
+ """ + assert self.connector_scheduler is not None + self.connector_scheduler.set_xfer_handshake_metadata(metadata) ############################################################ # Worker Side Methods @@ -276,6 +282,21 @@ def wait_for_save(self): def shutdown(self): if self.connector_worker is not None: self.connector_worker.shutdown() + if self.connector_scheduler is not None: + self.connector_scheduler.shutdown() + + def get_handshake_metadata(self) -> Optional[KVConnectorHandshakeMetadata]: + """ + Get the KVConnector handshake metadata for this connector. + This metadata is used for out-of-band connector handshake + between P/D workers. + + Returns: + KVConnectorHandshakeMetadata: the handshake metadata. + None if no handshake metadata is available. + """ + assert self.connector_worker is not None + return self.connector_worker.xfer_handshake_metadata class NixlConnectorScheduler: @@ -286,14 +307,14 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.block_size = vllm_config.cache_config.block_size self.engine_id: EngineId = engine_id self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST - self.side_channel_port = ( - envs.VLLM_NIXL_SIDE_CHANNEL_PORT - + vllm_config.parallel_config.data_parallel_rank - * vllm_config.parallel_config.tensor_parallel_size - ) + self.side_channel_port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT self.use_host_buffer = vllm_config.kv_transfer_config.kv_buffer_device == "cpu" logger.info("Initializing NIXL Scheduler %s", engine_id) + # Background thread for handling new handshake requests. + self._nixl_handshake_listener_t: Optional[threading.Thread] = None + self._encoded_xfer_handshake_metadata: dict[int, dict[int, Any]] = {} + # Requests that need to start recv/send. # New requests are added by update_state_after_alloc in # the scheduler. Used to make metadata passed to Worker. @@ -306,6 +327,67 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # remote prefill or aborted. self._reqs_not_processed: set[ReqId] = set() + def shutdown(self): + if self._nixl_handshake_listener_t is not None: + self._nixl_handshake_listener_t.join(timeout=0) + self._nixl_handshake_listener_t = None + + def set_xfer_handshake_metadata(self, metadata: dict[int, dict[int, dict]]) -> None: + """ + Set the KV connector handshake metadata for this connector. + + Args: + metadata (dict): the handshake metadata to set. + """ + encoded_data = {} + for dp_rank, tp_metadata in metadata.items(): + encoded_data[dp_rank] = {} + for tp_rank, rank_metadata in tp_metadata.items(): + encoder = msgspec.msgpack.Encoder() + encoded_data[dp_rank][tp_rank] = encoder.encode(rank_metadata) + logger.debug("Dp rank %d, Tp rank %d: Size of encoded NixlAgentMetadata: %s bytes", + dp_rank, tp_rank, str(len(encoded_data[dp_rank][tp_rank]))) + self._encoded_xfer_handshake_metadata = encoded_data + + # Only start the listener when we have metadata to serve. + if self._nixl_handshake_listener_t is None: + ready_event = threading.Event() + self._nixl_handshake_listener_t = threading.Thread( + target=self._nixl_handshake_listener, + args=(encoded_data, ready_event, self.side_channel_port), + daemon=True, + name="nixl_handshake_listener", + ) + self._nixl_handshake_listener_t.start() + ready_event.wait() # Wait for listener ZMQ socket to be ready. 
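For reference, the requester side of this side channel (mirroring the `_nixl_handshake` changes in this commit) can be sketched as below. `query_handshake_metadata` is a hypothetical helper, host/port are placeholders, and `GET_META_MSG` is the module-level constant already used in this file; later commits in this series change the request tuple, so this only reflects the protocol as of this commit.

import msgspec.msgpack
import zmq

def query_handshake_metadata(host: str, port: int, dp_rank: int, tp_rank: int) -> bytes:
    # Illustrative sketch only: ask the scheduler-side listener for the
    # msgpack-encoded NixlAgentMetadata of one (dp_rank, tp_rank) worker.
    ctx = zmq.Context()
    sock = ctx.socket(zmq.REQ)
    try:
        sock.connect(f"tcp://{host}:{port}")
        # The listener decodes a msgpack-encoded (GET_META_MSG, dp_rank, tp_rank) tuple.
        sock.send(msgspec.msgpack.encode((GET_META_MSG, dp_rank, tp_rank)))
        # Reply is the encoded NixlAgentMetadata for that rank.
        return sock.recv()
    finally:
        sock.close()
        ctx.term()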
+ + @staticmethod + def _nixl_handshake_listener( + encoded_data: dict[int, dict[int, Any]], + ready_event: threading.Event, + port: int, + ): + """Background thread for getting new NIXL handshakes.""" + # NOTE(rob): this is a simple implementation. We will move + # to a better approach via HTTP endpoint soon. + + # Listen for new requests for metadata. + host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST + path = make_zmq_path("tcp", host, port) + logger.debug("Starting listening on path: %s", path) + with zmq_ctx(zmq.ROUTER, path) as sock: + ready_event.set() + while True: + identity, _, msg = sock.recv_multipart() + # Decode the message which contains (GET_META_MSG, rank) + msg, target_dp_rank, target_tp_rank = msgspec.msgpack.decode(msg) + logger.debug("Received message for dp rank %s tp rank %s", target_dp_rank, target_tp_rank) + if msg != GET_META_MSG: + logger.warning( + "Connection listener got unexpected message %s", msg) + sock.send_multipart( + (identity, b"", encoded_data[target_dp_rank][target_tp_rank])) + def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int ) -> tuple[int, bool]: @@ -536,16 +618,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict) - # NIXL handshake port. - # NOTE(rob): Within a DP group, each DP rank gets its own - # base port (which is sent in the KVTransferParams). - # Each TP rank listens/queries on the base_port + tp_rank. - self.side_channel_port: int = ( - envs.VLLM_NIXL_SIDE_CHANNEL_PORT - + vllm_config.parallel_config.data_parallel_rank - * vllm_config.parallel_config.tensor_parallel_size - ) - # Metadata. self.engine_id: EngineId = engine_id self.tp_rank = get_tensor_model_parallel_rank() @@ -614,8 +686,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # Set of requests that have been part of a batch, regardless of status. self._reqs_to_process: set[ReqId] = set() - # Background thread for handling new handshake requests. - self._nixl_handshake_listener_t: Optional[threading.Thread] = None + # Handshake metadata of this worker for NIXL transfers. + self.xfer_handshake_metadata: Optional[NixlAgentMetadata] = None # Background thread for initializing new NIXL handshakes. self._handshake_initiation_executor = ThreadPoolExecutor( # NIXL is not guaranteed to be thread-safe, limit 1 worker. @@ -659,34 +731,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int) self.xfer_stats = NixlKVConnectorStats() - @staticmethod - def _nixl_handshake_listener( - metadata: NixlAgentMetadata, - ready_event: threading.Event, - base_port: int, - tp_rank: int, - ): - """Background thread for getting new NIXL handshakes.""" - # NOTE(rob): this is a simple implementation. We will move - # to a better approach via HTTP endpoint soon. - - encoder = msgspec.msgpack.Encoder() - encoded_data = encoder.encode(metadata) - size_in_bytes = len(encoded_data) - logger.debug("Size of encoded NixlAgentMetadata: %s bytes", str(size_in_bytes)) - - # Listen for new requests for metadata. 
- host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST - path = make_zmq_path("tcp", host, base_port + tp_rank) - logger.debug("Starting listening on path: %s", path) - with zmq_ctx(zmq.ROUTER, path) as sock: - ready_event.set() - while True: - identity, _, msg = sock.recv_multipart() - if msg != GET_META_MSG: - logger.warning("Connection listener got unexpected message %s", msg) - sock.send_multipart((identity, b"", encoded_data)) - def _nixl_handshake( self, host: str, @@ -705,15 +749,18 @@ def _nixl_handshake( # Handshake only with the remote TP rank that current local rank will # pull from. With homogeneous TP it happens to be the same rank_i. tp_ratio = self._tp_size[self.engine_id] // remote_tp_size - p_remote_rank = self.tp_rank // tp_ratio - path = make_zmq_path("tcp", host, port + p_remote_rank) + p_remote_tp_rank = self.tp_rank // tp_ratio + path = make_zmq_path("tcp", host, port) logger.debug( - "Querying metadata on path: %s at remote rank %s", path, p_remote_rank + "Querying metadata on path: %s at remote tp rank %s", path, p_remote_tp_rank ) + # [WIP] gluo: using dp_rank 0 data (standard for disaggregated prefill-decode) is sufficient? + p_remote_dp_rank = 0 # Send query for the request. with zmq_ctx(zmq.REQ, path) as sock: - sock.send(GET_META_MSG) + msg = msgspec.msgpack.encode((GET_META_MSG, p_remote_dp_rank, p_remote_tp_rank)) + sock.send(msg) metadata_bytes = sock.recv() decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) metadata = decoder.decode(metadata_bytes) @@ -732,7 +779,7 @@ def _nixl_handshake( # Register Remote agent. remote_agent_name = self.add_remote_agent( - metadata, p_remote_rank, remote_tp_size + metadata, p_remote_tp_rank, remote_tp_size ) setup_agent_time = time.perf_counter() logger.debug( @@ -741,7 +788,7 @@ def _nixl_handshake( ) # Remote rank -> agent name. - return {p_remote_rank: remote_agent_name} + return {p_remote_tp_rank: remote_agent_name} def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None: """ @@ -973,7 +1020,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): assert len(self.block_window_per_layer) == self.num_layers # After KV Caches registered, listen for new connections. - metadata = NixlAgentMetadata( + self.xfer_handshake_metadata = NixlAgentMetadata( engine_id=self.engine_id, agent_metadata=self.nixl_wrapper.get_agent_metadata(), kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], @@ -982,15 +1029,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): attn_backend_name=self.backend_name, kv_cache_layout=self.kv_cache_layout, ) - ready_event = threading.Event() - self._nixl_handshake_listener_t = threading.Thread( - target=self._nixl_handshake_listener, - args=(metadata, ready_event, self.side_channel_port, self.tp_rank), - daemon=True, - name="nixl_handshake_listener", - ) - self._nixl_handshake_listener_t.start() - ready_event.wait() # Wait for listener ZMQ socket to be ready. 
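With the listener no longer owned by the worker, the metadata built above travels worker -> executor -> engine core -> scheduler connector. A minimal sketch of that flow, assuming the names introduced elsewhere in this commit (`get_kv_connector_handshake_metadata`, `set_xfer_handshake_metadata`) and simplified types:

def collect_and_publish_handshake_metadata(model_executor, kv_connector) -> None:
    # Each worker returns {dp_rank: {tp_rank: NixlAgentMetadata}} or None.
    per_worker = model_executor.collective_rpc("get_kv_connector_handshake_metadata")
    merged: dict[int, dict[int, object]] = {}
    for worker_dict in per_worker:
        if worker_dict is None:
            continue
        for dp_rank, tp_dict in worker_dict.items():
            merged.setdefault(dp_rank, {}).update(tp_dict)
    # The scheduler-side connector encodes the entries and starts the
    # ZMQ side-channel listener on first use.
    kv_connector.set_xfer_handshake_metadata(merged)

`collect_and_publish_handshake_metadata` is a hypothetical wrapper used only for illustration; the real wiring lives in `EngineCore.__init__` in the `vllm/v1/engine/core.py` hunk later in this patch.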
def add_remote_agent( self, @@ -1528,9 +1566,6 @@ def get_kv_connector_stats(self) -> Optional[KVConnectorStats]: def shutdown(self): """Shutdown the connector worker.""" self._handshake_initiation_executor.shutdown(wait=False) - if self._nixl_handshake_listener_t is not None: - self._nixl_handshake_listener_t.join(timeout=0) - self._nixl_handshake_listener_t = None for handles in self._recving_transfers.values(): for handle, _ in handles: self.nixl_wrapper.release_xfer_handle(handle) diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py index 8206f23d1878..982fbc97e9d7 100644 --- a/vllm/executor/uniproc_executor.py +++ b/vllm/executor/uniproc_executor.py @@ -4,7 +4,7 @@ from concurrent.futures import Future, ThreadPoolExecutor from functools import cached_property from multiprocessing import Lock -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch import torch.distributed as dist @@ -93,6 +93,10 @@ def collective_rpc( future.set_exception(e) return [future] + def get_kv_connector_handshake_metadata(self) -> List[Optional[Dict]]: + """Get KV connector handshake metadata from all workers.""" + return self.collective_rpc("get_kv_connector_handshake_metadata") + def check_health(self) -> None: # UniProcExecutor will always be healthy as long as # it's running. diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 4826d7c589a7..9774ce80c148 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -160,6 +160,28 @@ def __init__( vllm_config, mm_registry ) + # If a KV connector is initialized for scheduler, we want to collect + # handshake metadata from all workers so the connector in the scheduler + # will have the full context + if self.scheduler.get_kv_connector() is not None: + # Collect and store KV connector xfer metadata from workers + # (after KV cache registration) + xfer_handshake_metadata = self.model_executor.get_kv_connector_handshake_metadata() + + if xfer_handshake_metadata: + # xfer_handshake_metadata is list of dicts from workers + # Each dict already has structure {dp_rank: {tp_rank: metadata}} + # Merge all worker dicts into a single dict + content: dict[int, dict[int, dict[int, Any]]] = {} + for worker_dict in xfer_handshake_metadata: + if worker_dict is not None: + # Deep merge nested dictionaries instead of overwrite + for dp_rank, tp_dict in worker_dict.items(): + if dp_rank not in content: + content[dp_rank] = {} + content[dp_rank].update(tp_dict) + self.scheduler.get_kv_connector().set_xfer_handshake_metadata(content) + # Setup batch queue for pipeline parallelism. # Batch queue for scheduled batches. 
This enables us to asynchronously # schedule and execute batches, and is required by pipeline parallelism diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 064e4b2bbf18..34ed7e820c4d 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -99,6 +99,10 @@ def collective_rpc( ) -> list[Any]: raise NotImplementedError + def get_kv_connector_handshake_metadata( + self) -> list[dict[int, dict[int, dict]]]: + return self.collective_rpc("get_kv_connector_handshake_metadata") + def execute_model( self, scheduler_output: SchedulerOutput, diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 271aabb9e227..b97a629b9d7e 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -19,7 +19,9 @@ init_distributed_environment, set_custom_all_reduce, ) -from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized +from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized, + get_kv_transfer_group, + has_kv_transfer_group) from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -318,6 +320,25 @@ def determine_available_memory(self) -> int: gc.collect() return int(self.available_kv_cache_memory_bytes) + + def get_kv_connector_handshake_metadata(self) -> Optional[dict]: + """Get KV connector metadata from this worker if available.""" + + if not has_kv_transfer_group(): + return None + + connector = get_kv_transfer_group() + metadata = connector.get_handshake_metadata() + if metadata is None: + logger.warning( + "KV connector metadata is not available. " + "This may happen if the KV connector is not initialized " + "or the worker is not part of a disaggregated KV cache setup.") + return None + + tp_rank = get_tp_group().rank_in_group + dp_rank = self.vllm_config.parallel_config.data_parallel_rank_local + return {dp_rank: {tp_rank: metadata}} def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: return self.model_runner.get_kv_cache_spec() From 8fb478aecfb9a4368fb2584d52d070ae1e67372b Mon Sep 17 00:00:00 2001 From: Guan Luo Date: Thu, 9 Oct 2025 17:08:34 -0700 Subject: [PATCH 02/19] test: add test Signed-off-by: Guan Luo --- .../kv_connector/unit/test_nixl_connector.py | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index a1f53cb25563..200feb47f79e 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -284,6 +284,68 @@ def test_prompt_less_than_block_size(): assert len(kv_connector_metadata.reqs_to_recv) == 1 assert len(scheduler_output.scheduled_new_reqs) == 0 +def test_kv_transfer_metadata(): + """Unit test for basic NixlConnector interface functionality.""" + + # Test setup, we creates a scheduler that contains a NixlConnector + # of role SCHEDULER, and expect it to be serving NixlAgentMetadata from + # all workers of the instance. + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # Create two NixlConnector of role WORKER, one is the worker of + # the scheduler (prefill), the other is a worker of decode instance. 
+ prefill_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + # gather connector metadata from all workers, the scheduler connector + # expects metadata to be in Dict[int, Dict[int, KVConnectorHandshakeMetadata]], + # where the first key is the dp_rank, the second key is the tp_rank. + metadata = {0: {0: prefill_connector.get_handshake_metadata()}} + scheduler.get_kv_connector().set_xfer_handshake_metadata(metadata) + + # 2 Full Blocks and 1 Half Block. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request = create_request( + request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS, + do_remote_prefill=True, + ) + request_id = request.request_id + + scheduler.add_request(request) + + # Remote Prefill, triggers NixlConnectorMetadata for handshake. + scheduler_output = scheduler.schedule() + kv_connector_metadata = scheduler_output.kv_connector_metadata + assert kv_connector_metadata is not None + assert isinstance(kv_connector_metadata, NixlConnectorMetadata) + + # Decode connector will be able to create handshake with the prefill connector. + decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + # Here we are testing the retrieval of NIXLAgentMetadata, + # by knowing the implementation detail, we override the add_remote_agent + # to validate the metadata received is the same as the one in prefill_connector. + + received_metadata = None + def mock_add_remote_agent(self, agent_metadata: NixlAgentMetadata, remote_tp_rank: int, remote_tp_size: int): + nonlocal received_metadata + received_metadata = (agent_metadata, remote_tp_rank, remote_tp_size) + return "remote_agent" + decode_connector.connector_worker.add_remote_agent = mock_add_remote_agent + + meta = kv_connector_metadata.reqs_to_recv[request_id] + decode_connector.connector_worker._nixl_handshake(meta.remote_host, + meta.remote_port, + meta.tp_size, + meta.remote_engine_id) + assert received_metadata is not None + assert received_metadata[1] == 0 # remote_tp_rank + assert received_metadata[2] == 1 # remote_tp_size + assert metadata[0][0] == received_metadata[0] + class FakeNixlConnectorWorker(NixlConnectorWorker): REMOTE_ENGINE_ID = "remote_engine" From f5e82f5c8cf4c46f8911ea34304998d244c2c522 Mon Sep 17 00:00:00 2001 From: Guan Luo Date: Thu, 9 Oct 2025 23:34:00 -0700 Subject: [PATCH 03/19] chore: style and fix test Signed-off-by: Guan Luo --- .../kv_connector/unit/test_nixl_connector.py | 23 +++++++----- .../kv_transfer/kv_connector/v1/base.py | 3 +- .../kv_connector/v1/nixl_connector.py | 35 +++++++++++++------ vllm/v1/engine/core.py | 6 ++-- vllm/v1/executor/abstract.py | 3 +- vllm/v1/worker/gpu_worker.py | 13 ++++--- 6 files changed, 54 insertions(+), 29 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 200feb47f79e..ec862987b89d 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -284,7 +284,8 @@ def test_prompt_less_than_block_size(): assert len(kv_connector_metadata.reqs_to_recv) == 1 assert len(scheduler_output.scheduled_new_reqs) == 0 -def test_kv_transfer_metadata(): + +def test_kv_transfer_metadata(dist_init): """Unit test for basic NixlConnector interface functionality.""" # Test setup, we creates a scheduler that contains a NixlConnector @@ -323,24 +324,30 @@ def test_kv_transfer_metadata(): assert kv_connector_metadata is not 
None assert isinstance(kv_connector_metadata, NixlConnectorMetadata) - # Decode connector will be able to create handshake with the prefill connector. + # Decode connector will be able to create handshake with the prefill connector. decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) # Here we are testing the retrieval of NIXLAgentMetadata, # by knowing the implementation detail, we override the add_remote_agent # to validate the metadata received is the same as the one in prefill_connector. - + received_metadata = None - def mock_add_remote_agent(self, agent_metadata: NixlAgentMetadata, remote_tp_rank: int, remote_tp_size: int): + + def mock_add_remote_agent( + self, + agent_metadata: NixlAgentMetadata, + remote_tp_rank: int, + remote_tp_size: int, + ): nonlocal received_metadata received_metadata = (agent_metadata, remote_tp_rank, remote_tp_size) return "remote_agent" + decode_connector.connector_worker.add_remote_agent = mock_add_remote_agent meta = kv_connector_metadata.reqs_to_recv[request_id] - decode_connector.connector_worker._nixl_handshake(meta.remote_host, - meta.remote_port, - meta.tp_size, - meta.remote_engine_id) + decode_connector.connector_worker._nixl_handshake( + meta.remote_host, meta.remote_port, meta.tp_size, meta.remote_engine_id + ) assert received_metadata is not None assert received_metadata[1] == 0 # remote_tp_rank assert received_metadata[2] == 1 # remote_tp_size diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 47ab3083b5d6..e57ef39a1a90 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -82,6 +82,7 @@ class KVConnectorHandshakeMetadata(ABC): # noqa: B024 Metadata used for out of band connector handshakeandshake between P/D workers. This needs to serializeable. """ + pass @@ -450,4 +451,4 @@ def set_xfer_handshake_metadata(self, metadata: dict[int, dict[int, dict]]) -> N Args: metadata (dict): the handshake metadata to set. """ - pass + return None diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 0d2a2aa666d8..a00ed988eae0 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -89,6 +89,7 @@ _NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices()) +@dataclass class NixlAgentMetadata(KVConnectorHandshakeMetadata): engine_id: str agent_metadata: bytes @@ -213,7 +214,7 @@ def request_finished( ) -> tuple[bool, Optional[dict[str, Any]]]: assert self.connector_scheduler is not None return self.connector_scheduler.request_finished(request, block_ids) - + def set_xfer_handshake_metadata(self, metadata: dict[int, dict[int, dict]]) -> None: """ Set the KV connector handshake metadata for this connector. @@ -339,14 +340,18 @@ def set_xfer_handshake_metadata(self, metadata: dict[int, dict[int, dict]]) -> N Args: metadata (dict): the handshake metadata to set. 
""" - encoded_data = {} + encoded_data: dict[int, dict[int, bytes]] = {} for dp_rank, tp_metadata in metadata.items(): encoded_data[dp_rank] = {} for tp_rank, rank_metadata in tp_metadata.items(): encoder = msgspec.msgpack.Encoder() encoded_data[dp_rank][tp_rank] = encoder.encode(rank_metadata) - logger.debug("Dp rank %d, Tp rank %d: Size of encoded NixlAgentMetadata: %s bytes", - dp_rank, tp_rank, str(len(encoded_data[dp_rank][tp_rank]))) + logger.debug( + "Dp rank %d, Tp rank %d: encoded NixlAgentMetadata size: %s bytes", + dp_rank, + tp_rank, + str(len(encoded_data[dp_rank][tp_rank])), + ) self._encoded_xfer_handshake_metadata = encoded_data # Only start the listener when we have metadata to serve. @@ -381,12 +386,16 @@ def _nixl_handshake_listener( identity, _, msg = sock.recv_multipart() # Decode the message which contains (GET_META_MSG, rank) msg, target_dp_rank, target_tp_rank = msgspec.msgpack.decode(msg) - logger.debug("Received message for dp rank %s tp rank %s", target_dp_rank, target_tp_rank) + logger.debug( + "Received message for dp rank %s tp rank %s", + target_dp_rank, + target_tp_rank, + ) if msg != GET_META_MSG: - logger.warning( - "Connection listener got unexpected message %s", msg) + logger.warning("Connection listener got unexpected message %s", msg) sock.send_multipart( - (identity, b"", encoded_data[target_dp_rank][target_tp_rank])) + (identity, b"", encoded_data[target_dp_rank][target_tp_rank]) + ) def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int @@ -754,12 +763,15 @@ def _nixl_handshake( logger.debug( "Querying metadata on path: %s at remote tp rank %s", path, p_remote_tp_rank ) - # [WIP] gluo: using dp_rank 0 data (standard for disaggregated prefill-decode) is sufficient? + # [WIP] gluo: using dp_rank 0 data (standard for disaggregated prefill-decode) + # is sufficient? p_remote_dp_rank = 0 # Send query for the request. with zmq_ctx(zmq.REQ, path) as sock: - msg = msgspec.msgpack.encode((GET_META_MSG, p_remote_dp_rank, p_remote_tp_rank)) + msg = msgspec.msgpack.encode( + (GET_META_MSG, p_remote_dp_rank, p_remote_tp_rank) + ) sock.send(msg) metadata_bytes = sock.recv() decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) @@ -1448,7 +1460,8 @@ def _read_blocks( # workers will issue xfers to parts of the P worker remote kv caches. # Get descs ids. 
- local_block_descs_ids: np.ndarray + local_block_descs_ids: np.ndar621 + remote_block_descs_ids: np.ndarray if not self.block_window_per_layer: # Default case: assume global attention diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 9774ce80c148..4d473ad250e8 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -162,11 +162,13 @@ def __init__( # If a KV connector is initialized for scheduler, we want to collect # handshake metadata from all workers so the connector in the scheduler - # will have the full context + # will have the full context if self.scheduler.get_kv_connector() is not None: # Collect and store KV connector xfer metadata from workers # (after KV cache registration) - xfer_handshake_metadata = self.model_executor.get_kv_connector_handshake_metadata() + xfer_handshake_metadata = ( + self.model_executor.get_kv_connector_handshake_metadata() + ) if xfer_handshake_metadata: # xfer_handshake_metadata is list of dicts from workers diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 34ed7e820c4d..4964a952dd3c 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -99,8 +99,7 @@ def collective_rpc( ) -> list[Any]: raise NotImplementedError - def get_kv_connector_handshake_metadata( - self) -> list[dict[int, dict[int, dict]]]: + def get_kv_connector_handshake_metadata(self) -> list[dict[int, dict[int, dict]]]: return self.collective_rpc("get_kv_connector_handshake_metadata") def execute_model( diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index b97a629b9d7e..b4d05d395e5d 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -19,9 +19,11 @@ init_distributed_environment, set_custom_all_reduce, ) -from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized, - get_kv_transfer_group, - has_kv_transfer_group) +from vllm.distributed.kv_transfer import ( + ensure_kv_transfer_initialized, + get_kv_transfer_group, + has_kv_transfer_group, +) from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -320,7 +322,7 @@ def determine_available_memory(self) -> int: gc.collect() return int(self.available_kv_cache_memory_bytes) - + def get_kv_connector_handshake_metadata(self) -> Optional[dict]: """Get KV connector metadata from this worker if available.""" @@ -333,7 +335,8 @@ def get_kv_connector_handshake_metadata(self) -> Optional[dict]: logger.warning( "KV connector metadata is not available. " "This may happen if the KV connector is not initialized " - "or the worker is not part of a disaggregated KV cache setup.") + "or the worker is not part of a disaggregated KV cache setup." 
+ ) return None tp_rank = get_tp_group().rank_in_group From aa41ea573083e82a757bf21bec0ee50fd8f05b1d Mon Sep 17 00:00:00 2001 From: Guan Luo Date: Fri, 10 Oct 2025 01:27:25 -0700 Subject: [PATCH 04/19] chore: fix pre-commit Signed-off-by: Guan Luo --- vllm/executor/uniproc_executor.py | 4 ++-- vllm/v1/engine/core.py | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py index 982fbc97e9d7..43ffabe474dc 100644 --- a/vllm/executor/uniproc_executor.py +++ b/vllm/executor/uniproc_executor.py @@ -4,7 +4,7 @@ from concurrent.futures import Future, ThreadPoolExecutor from functools import cached_property from multiprocessing import Lock -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Optional, Union import torch import torch.distributed as dist @@ -93,7 +93,7 @@ def collective_rpc( future.set_exception(e) return [future] - def get_kv_connector_handshake_metadata(self) -> List[Optional[Dict]]: + def get_kv_connector_handshake_metadata(self) -> list[Optional[dict]]: """Get KV connector handshake metadata from all workers.""" return self.collective_rpc("get_kv_connector_handshake_metadata") diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 4d473ad250e8..a3a2634fa5db 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -163,7 +163,8 @@ def __init__( # If a KV connector is initialized for scheduler, we want to collect # handshake metadata from all workers so the connector in the scheduler # will have the full context - if self.scheduler.get_kv_connector() is not None: + kv_connector = self.scheduler.get_kv_connector() + if kv_connector is not None: # Collect and store KV connector xfer metadata from workers # (after KV cache registration) xfer_handshake_metadata = ( @@ -182,7 +183,7 @@ def __init__( if dp_rank not in content: content[dp_rank] = {} content[dp_rank].update(tp_dict) - self.scheduler.get_kv_connector().set_xfer_handshake_metadata(content) + kv_connector.set_xfer_handshake_metadata(content) # Setup batch queue for pipeline parallelism. # Batch queue for scheduled batches. This enables us to asynchronously @@ -199,7 +200,7 @@ def __init__( self.request_block_hasher: Optional[Callable[[Request], list[BlockHash]]] = None if ( self.vllm_config.cache_config.enable_prefix_caching - or self.scheduler.get_kv_connector() is not None + or kv_connector is not None ): block_size = vllm_config.cache_config.block_size caching_hash_fn = get_hash_fn_by_name( From d3abc468b08352f7f90e9da427823bed272b2697 Mon Sep 17 00:00:00 2001 From: Guan Luo Date: Fri, 10 Oct 2025 02:36:00 -0700 Subject: [PATCH 05/19] fix: fix test Signed-off-by: Guan Luo --- .../kv_connector/unit/test_nixl_connector.py | 48 ++++++++++++------- 1 file changed, 31 insertions(+), 17 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index ec862987b89d..af1e9486acde 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -296,14 +296,31 @@ def test_kv_transfer_metadata(dist_init): # Create two NixlConnector of role WORKER, one is the worker of # the scheduler (prefill), the other is a worker of decode instance. + + # Prefill connector will register KV cache to populate proper handshake + # metadata. 
prefill_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape( + num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 + ) + shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + kv_caches = { + "layer0": shared_tensor, + "layer1": unique_tensor, + "layer2": shared_tensor, + } + prefill_connector.register_kv_caches(kv_caches) + + # Simulate EngineCore initialization that would # gather connector metadata from all workers, the scheduler connector # expects metadata to be in Dict[int, Dict[int, KVConnectorHandshakeMetadata]], # where the first key is the dp_rank, the second key is the tp_rank. metadata = {0: {0: prefill_connector.get_handshake_metadata()}} scheduler.get_kv_connector().set_xfer_handshake_metadata(metadata) - # 2 Full Blocks and 1 Half Block. + # Simulate a request that finishes prefill, which returns + # corresponding NixlConnectorMetadata for decode instance. BLOCK_SIZE = vllm_config.cache_config.block_size NUM_EXTERNAL_FULL_BLOCKS = 2 NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) @@ -312,28 +329,23 @@ def test_kv_transfer_metadata(dist_init): request_id=1, block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, - do_remote_prefill=True, + do_remote_decode=True, ) - request_id = request.request_id - - scheduler.add_request(request) - - # Remote Prefill, triggers NixlConnectorMetadata for handshake. - scheduler_output = scheduler.schedule() - kv_connector_metadata = scheduler_output.kv_connector_metadata - assert kv_connector_metadata is not None - assert isinstance(kv_connector_metadata, NixlConnectorMetadata) + request.status = RequestStatus.FINISHED_LENGTH_CAPPED + delay, kv_connector_metadata = scheduler.get_kv_connector().request_finished( + request, [0, 1, 2] + ) + assert delay # Decode connector will be able to create handshake with the prefill connector. decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) - # Here we are testing the retrieval of NIXLAgentMetadata, - # by knowing the implementation detail, we override the add_remote_agent - # to validate the metadata received is the same as the one in prefill_connector. + # Here we are testing the retrieval of NIXLAgentMetadata. + # Knowing the implementation detail, we override the add_remote_agent + # to validate the metadata received is the same as the one in prefill_connector. 
received_metadata = None def mock_add_remote_agent( - self, agent_metadata: NixlAgentMetadata, remote_tp_rank: int, remote_tp_size: int, @@ -344,9 +356,11 @@ def mock_add_remote_agent( decode_connector.connector_worker.add_remote_agent = mock_add_remote_agent - meta = kv_connector_metadata.reqs_to_recv[request_id] decode_connector.connector_worker._nixl_handshake( - meta.remote_host, meta.remote_port, meta.tp_size, meta.remote_engine_id + kv_connector_metadata["remote_host"], + kv_connector_metadata["remote_port"], + kv_connector_metadata["tp_size"], + kv_connector_metadata["remote_engine_id"], ) assert received_metadata is not None assert received_metadata[1] == 0 # remote_tp_rank From 813dab1c07ad5db1320a035f1e8377d642bcade4 Mon Sep 17 00:00:00 2001 From: Guan Luo Date: Fri, 10 Oct 2025 04:02:12 -0700 Subject: [PATCH 06/19] fix: fix NIXL handshake listener cleanup Signed-off-by: Guan Luo --- .../kv_connector/unit/test_nixl_connector.py | 21 ++++++++++++++++--- .../kv_connector/v1/nixl_connector.py | 20 +++++++++++++++--- 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index af1e9486acde..adcdad03b59a 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -28,6 +28,7 @@ NixlAgentMetadata, NixlConnector, NixlConnectorMetadata, + NixlConnectorScheduler, NixlConnectorWorker, NixlKVConnectorStats, ) @@ -317,7 +318,8 @@ def test_kv_transfer_metadata(dist_init): # expects metadata to be in Dict[int, Dict[int, KVConnectorHandshakeMetadata]], # where the first key is the dp_rank, the second key is the tp_rank. metadata = {0: {0: prefill_connector.get_handshake_metadata()}} - scheduler.get_kv_connector().set_xfer_handshake_metadata(metadata) + scheduler_connector = scheduler.get_kv_connector() + scheduler_connector.set_xfer_handshake_metadata(metadata) # Simulate a request that finishes prefill, which returns # corresponding NixlConnectorMetadata for decode instance. @@ -367,6 +369,9 @@ def mock_add_remote_agent( assert received_metadata[2] == 1 # remote_tp_size assert metadata[0][0] == received_metadata[0] + # Need to shutdown the background thread to release NIXL side channel port + scheduler_connector.shutdown() + class FakeNixlConnectorWorker(NixlConnectorWorker): REMOTE_ENGINE_ID = "remote_engine" @@ -958,6 +963,8 @@ def _run_abort_timeout_test(llm_kwargs: dict, timeout: int): _ = llm.generate([f"What is the capital of France? {padding}"], sampling_params) # Request-0 times out and is cleared! 
assert "0" not in req_to_blocks + # Need to shutdown the background thread to release NIXL side channel port + llm.llm_engine.engine_core.shutdown() def test_register_kv_caches(dist_init): @@ -1127,12 +1134,15 @@ def test_shutdown_cleans_up_resources(dist_init): """Test that shutdown() properly cleans up all resources.""" vllm_config = create_vllm_config() + scheduler = NixlConnectorScheduler( + vllm_config, vllm_config.kv_transfer_config.engine_id + ) worker = NixlConnectorWorker(vllm_config, vllm_config.kv_transfer_config.engine_id) nixl_wrapper = worker.nixl_wrapper with ( patch.object(worker, "_handshake_initiation_executor") as mock_exec, - patch.object(worker, "_nixl_handshake_listener_t") as mock_listener, + patch.object(scheduler, "_nixl_handshake_listener_t") as mock_listener, patch.object(nixl_wrapper, "release_xfer_handle") as mock_rel_xfer, patch.object(nixl_wrapper, "release_dlist_handle") as mock_rel_dlist, patch.object(nixl_wrapper, "remove_remote_agent") as mock_rem_agent, @@ -1151,7 +1161,12 @@ def test_shutdown_cleans_up_resources(dist_init): worker.shutdown() mock_exec.shutdown.assert_called_with(wait=False) - mock_listener.join.assert_called_once_with(timeout=0) + + # Same sequence on scheduler.shutdown() + scheduler.shutdown() + scheduler.shutdown() + scheduler.shutdown() + mock_listener.join.assert_called_once() mock_rel_xfer.assert_called_once_with(123) assert mock_rel_dlist.call_count == 2 diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index a00ed988eae0..f89eb8e66658 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -315,6 +315,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # Background thread for handling new handshake requests. self._nixl_handshake_listener_t: Optional[threading.Thread] = None self._encoded_xfer_handshake_metadata: dict[int, dict[int, Any]] = {} + self._stop_event = threading.Event() # Requests that need to start recv/send. 
# New requests are added by update_state_after_alloc in @@ -329,8 +330,9 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self._reqs_not_processed: set[ReqId] = set() def shutdown(self): + self._stop_event.set() if self._nixl_handshake_listener_t is not None: - self._nixl_handshake_listener_t.join(timeout=0) + self._nixl_handshake_listener_t.join() self._nixl_handshake_listener_t = None def set_xfer_handshake_metadata(self, metadata: dict[int, dict[int, dict]]) -> None: @@ -359,7 +361,12 @@ def set_xfer_handshake_metadata(self, metadata: dict[int, dict[int, dict]]) -> N ready_event = threading.Event() self._nixl_handshake_listener_t = threading.Thread( target=self._nixl_handshake_listener, - args=(encoded_data, ready_event, self.side_channel_port), + args=( + encoded_data, + ready_event, + self._stop_event, + self.side_channel_port, + ), daemon=True, name="nixl_handshake_listener", ) @@ -370,6 +377,7 @@ def set_xfer_handshake_metadata(self, metadata: dict[int, dict[int, dict]]) -> N def _nixl_handshake_listener( encoded_data: dict[int, dict[int, Any]], ready_event: threading.Event, + stop_event: threading.Event, port: int, ): """Background thread for getting new NIXL handshakes.""" @@ -381,9 +389,15 @@ def _nixl_handshake_listener( path = make_zmq_path("tcp", host, port) logger.debug("Starting listening on path: %s", path) with zmq_ctx(zmq.ROUTER, path) as sock: + sock.setsockopt(zmq.RCVTIMEO, 1000) ready_event.set() while True: - identity, _, msg = sock.recv_multipart() + try: + identity, _, msg = sock.recv_multipart() + except zmq.Again: + if stop_event.is_set(): + break + continue # Decode the message which contains (GET_META_MSG, rank) msg, target_dp_rank, target_tp_rank = msgspec.msgpack.decode(msg) logger.debug( From 009ec48ba9b63c492777408d74116e9c87ee7e29 Mon Sep 17 00:00:00 2001 From: Guan Luo Date: Fri, 10 Oct 2025 17:53:08 -0700 Subject: [PATCH 07/19] fix: fix device ID use for NIXL memory registration Signed-off-by: Guan Luo --- .../kv_connector/v1/nixl_connector.py | 26 ++++++++++++++----- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index f89eb8e66658..111f606583b9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -94,6 +94,7 @@ class NixlAgentMetadata(KVConnectorHandshakeMetadata): engine_id: str agent_metadata: bytes kv_caches_base_addr: list[int] + device_id: list[int] num_blocks: int block_lens: list[int] attn_backend_name: str @@ -684,6 +685,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # Map of engine_id -> kv_caches_base_addr. For TP case, each local # rank will still only pull from a single remote TP worker. self.kv_caches_base_addr: dict[EngineId, list[int]] = {} + self.device_id: dict[EngineId, list[int]] = {} # Number of NIXL regions. Currently one region per cache # (so 1 per layer for MLA, otherwise 2 per layer) @@ -903,6 +905,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): caches_data = [] # With hybrid allocator, layers can share a kv cache tensor seen_base_addresses = [] + # Map from address to device ID + seen_addresses_device_id = [] # Note(tms): I modified this from the original region setup code. # K and V are now in different regions. 
Advantage is that we can @@ -926,6 +930,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): continue seen_base_addresses.append(base_addr) + seen_addresses_device_id.append(cache.get_device()) curr_tensor_size_bytes = cache.numel() * cache.element_size() if tensor_size_bytes is None: @@ -949,7 +954,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): "All kv cache tensors must have the same size" ) caches_data.append( - (base_addr, curr_tensor_size_bytes, self.tp_rank, "") + (base_addr, curr_tensor_size_bytes, cache.get_device(), "") ) logger.debug( @@ -959,6 +964,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): assert self.num_blocks != 0 self.kv_caches_base_addr[self.engine_id] = seen_base_addresses + self.device_id[self.engine_id] = seen_addresses_device_id self.num_regions = len(caches_data) self.num_layers = len(xfer_buffers.keys()) @@ -985,7 +991,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # Register local/src descr for NIXL xfer. blocks_data = [] - for i, base_addr in enumerate(seen_base_addresses): + for i, (base_addr, device_id) in enumerate( + zip(seen_base_addresses, seen_addresses_device_id) + ): kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) # NOTE With heter-TP, more blocks are prepared than what are # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We @@ -996,7 +1004,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): block_offset = block_id * self.block_len_per_layer[i] addr = base_addr + block_offset # (addr, len, device id) - blocks_data.append((addr, kv_block_len, self.tp_rank)) + blocks_data.append((addr, kv_block_len, device_id)) if self._use_flashinfer: # Separate and interleave K/V regions to maintain the same @@ -1007,7 +1015,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): addr = base_addr + block_offset # Register addresses for V cache (K registered first). v_addr = addr + kv_block_len - blocks_data.append((v_addr, kv_block_len, self.tp_rank)) + blocks_data.append((v_addr, kv_block_len, device_id)) logger.debug( "Created %s blocks for src engine %s and rank %s", len(blocks_data), @@ -1050,6 +1058,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): engine_id=self.engine_id, agent_metadata=self.nixl_wrapper.get_agent_metadata(), kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], + device_id=self.device_id[self.engine_id], num_blocks=self.num_blocks, block_lens=self.block_len_per_layer, attn_backend_name=self.backend_name, @@ -1175,10 +1184,13 @@ def add_remote_agent( # P KV cache along kv_head dim, of D worker's kv_head size (D>P). # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr + self.device_id[engine_id] = nixl_agent_meta.device_id assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer) # Register all remote blocks, but only the corresponding kv heads. - for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): + for i, (base_addr, device_id) in enumerate( + zip(nixl_agent_meta.kv_caches_base_addr, nixl_agent_meta.device_id) + ): kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) rank_offset = ( self.tp_rank % tp_ratio * kv_block_len @@ -1192,7 +1204,7 @@ def add_remote_agent( # self.block_len == remote_block_len//tp_ratio bytes. 
addr = base_addr + block_offset + rank_offset # (addr, len, device id) - blocks_data.append((addr, kv_block_len, remote_tp_rank)) + blocks_data.append((addr, kv_block_len, device_id)) if self._use_flashinfer: # With FlashInfer index V separately to allow head splitting. @@ -1200,7 +1212,7 @@ def add_remote_agent( block_offset = block_id * nixl_agent_meta.block_lens[i] addr = base_addr + block_offset + rank_offset v_addr = addr + nixl_agent_meta.block_lens[i] // 2 - blocks_data.append((v_addr, kv_block_len, remote_tp_rank)) + blocks_data.append((v_addr, kv_block_len, device_id)) logger.debug( "Created %s blocks for dst engine %s with remote rank %s and local rank %s", From e52448ff27abea310b04204957834a4c0513916f Mon Sep 17 00:00:00 2001 From: Guan Luo Date: Mon, 13 Oct 2025 19:25:14 -0700 Subject: [PATCH 08/19] chore: address comment Signed-off-by: Guan Luo --- vllm/distributed/kv_transfer/kv_connector/v1/base.py | 4 ++-- .../kv_transfer/kv_connector/v1/nixl_connector.py | 6 +++--- vllm/executor/uniproc_executor.py | 2 +- vllm/v1/worker/gpu_worker.py | 5 ++--- 4 files changed, 8 insertions(+), 9 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 9bc91854966e..c74bfba799c6 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -80,7 +80,7 @@ class KVConnectorRole(enum.Enum): class KVConnectorHandshakeMetadata(ABC): # noqa: B024 """ - Metadata used for out of band connector handshakeandshake between + Metadata used for out of band connector handshake between P/D workers. This needs to serializeable. """ @@ -281,7 +281,7 @@ def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]: """ return None - def get_handshake_metadata(self) -> Optional[KVConnectorHandshakeMetadata]: + def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None: """ Get the KVConnector handshake metadata for this connector. This metadata is used for out-of-band connector handshake diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index ff1aaf4c3e6e..28f312109ec3 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -287,7 +287,7 @@ def shutdown(self): if self.connector_scheduler is not None: self.connector_scheduler.shutdown() - def get_handshake_metadata(self) -> Optional[KVConnectorHandshakeMetadata]: + def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None: """ Get the KVConnector handshake metadata for this connector. This metadata is used for out-of-band connector handshake @@ -314,7 +314,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): logger.info("Initializing NIXL Scheduler %s", engine_id) # Background thread for handling new handshake requests. - self._nixl_handshake_listener_t: Optional[threading.Thread] = None + self._nixl_handshake_listener_t: threading.Thread | None = None self._encoded_xfer_handshake_metadata: dict[int, dict[int, Any]] = {} self._stop_event = threading.Event() @@ -1488,7 +1488,7 @@ def _read_blocks( # workers will issue xfers to parts of the P worker remote kv caches. # Get descs ids. 
- local_block_descs_ids: np.ndar621 + local_block_descs_ids: np.ndarray remote_block_descs_ids: np.ndarray if not self.block_window_per_layer: diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py index 103063411531..011b2b409428 100644 --- a/vllm/executor/uniproc_executor.py +++ b/vllm/executor/uniproc_executor.py @@ -86,7 +86,7 @@ def collective_rpc( future.set_exception(e) return [future] - def get_kv_connector_handshake_metadata(self) -> list[Optional[dict]]: + def get_kv_connector_handshake_metadata(self) -> list[dict | None]: """Get KV connector handshake metadata from all workers.""" return self.collective_rpc("get_kv_connector_handshake_metadata") diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 903d38e3b587..3adb7fca5456 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -324,15 +324,14 @@ def determine_available_memory(self) -> int: return int(self.available_kv_cache_memory_bytes) - def get_kv_connector_handshake_metadata(self) -> Optional[dict]: + def get_kv_connector_handshake_metadata(self) -> dict | None: """Get KV connector metadata from this worker if available.""" if not has_kv_transfer_group(): return None connector = get_kv_transfer_group() - metadata = connector.get_handshake_metadata() - if metadata is None: + if (metadata := connector.get_handshake_metadata()) is None: logger.warning( "KV connector metadata is not available. " "This may happen if the KV connector is not initialized " From fa8357ff10aa8d9c6d4a69674c8a33c549a496a3 Mon Sep 17 00:00:00 2001 From: Guan Luo <41310872+GuanLuo@users.noreply.github.com> Date: Mon, 20 Oct 2025 12:01:20 +0800 Subject: [PATCH 09/19] fix: fix test and CPU case Signed-off-by: Guan Luo <41310872+GuanLuo@users.noreply.github.com> --- tests/v1/kv_connector/unit/test_nixl_connector.py | 4 ++++ .../kv_transfer/kv_connector/v1/nixl_connector.py | 10 ++++++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index a46c0ff10d89..51f2bc1b5c29 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -291,6 +291,8 @@ def test_kv_transfer_metadata(dist_init): # of role SCHEDULER, and expect it to be serving NixlAgentMetadata from # all workers of the instance. 
vllm_config = create_vllm_config() + # in case the test runs on non-GPU machine + vllm_config.kv_transfer_config.kv_buffer_device = "cpu" scheduler = create_scheduler(vllm_config) # Create two NixlConnector of role WORKER, one is the worker of @@ -398,6 +400,7 @@ def _nixl_handshake( engine_id=self.REMOTE_ENGINE_ID, agent_metadata=FakeNixlWrapper.AGENT_METADATA, kv_caches_base_addr=[0], + device_id=[0], num_blocks=1, block_lens=self.block_len_per_layer, attn_backend_name=self.backend_name, @@ -644,6 +647,7 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init): engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID, agent_metadata=FakeNixlWrapper.AGENT_METADATA, kv_caches_base_addr=[0], + device_id=[0], num_blocks=1, block_lens=worker.block_len_per_layer, attn_backend_name=worker.backend_name, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 518db33b155a..4660cb3f9d1c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -962,7 +962,10 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): continue seen_base_addresses.append(base_addr) - seen_addresses_device_id.append(cache.get_device()) + # Need to make sure the device ID is non-negative for NIXL, + # Torch uses -1 to indicate CPU tensors while NIXL uses explicit + # memory type. + seen_addresses_device_id.append(max(cache.get_device(), 0)) curr_tensor_size_bytes = cache.numel() * cache.element_size() if tensor_size_bytes is None: @@ -985,8 +988,11 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): assert tensor_size_bytes == curr_tensor_size_bytes, ( "All kv cache tensors must have the same size" ) + # Need to make sure the device ID is non-negative for NIXL, + # Torch uses -1 to indicate CPU tensors while NIXL uses explicit + # memory type. 
caches_data.append( - (base_addr, curr_tensor_size_bytes, cache.get_device(), "") + (base_addr, curr_tensor_size_bytes, max(cache.get_device(), 0), "") ) logger.debug( From bb9f57767287bae5b5a338850d01cb0095fbd8c3 Mon Sep 17 00:00:00 2001 From: Guan Luo <41310872+GuanLuo@users.noreply.github.com> Date: Wed, 22 Oct 2025 12:14:15 +0800 Subject: [PATCH 10/19] chore: address comment Signed-off-by: Guan Luo <41310872+GuanLuo@users.noreply.github.com> --- .../kv_connector/unit/test_nixl_connector.py | 6 +-- .../kv_transfer/kv_connector/v1/base.py | 4 +- .../kv_connector/v1/nixl_connector.py | 48 +++++++++---------- vllm/executor/uniproc_executor.py | 4 -- vllm/v1/engine/core.py | 10 ++-- vllm/v1/executor/abstract.py | 3 +- vllm/v1/worker/gpu_worker.py | 10 ++-- 7 files changed, 36 insertions(+), 49 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 6c7df8712b51..3ea9ceda3755 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -284,7 +284,7 @@ def test_prompt_less_than_block_size(): assert len(scheduler_output.scheduled_new_reqs) == 0 -def test_kv_transfer_metadata(dist_init): +def test_kv_transfer_handshake(dist_init): """Unit test for basic NixlConnector interface functionality.""" # Test setup, we creates a scheduler that contains a NixlConnector @@ -315,9 +315,9 @@ def test_kv_transfer_metadata(dist_init): # Simulate EngineCore initialization that would # gather connector metadata from all workers, the scheduler connector - # expects metadata to be in Dict[int, Dict[int, KVConnectorHandshakeMetadata]], + # expects metadata to be in dict[int, KVConnectorHandshakeMetadata], # where the first key is the dp_rank, the second key is the tp_rank. - metadata = {0: {0: prefill_connector.get_handshake_metadata()}} + metadata = {0: prefill_connector.get_handshake_metadata()} scheduler_connector = scheduler.get_kv_connector() scheduler_connector.set_xfer_handshake_metadata(metadata) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 6bdea8e25bee..ab04b5499b01 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -453,11 +453,11 @@ def build_kv_connector_stats( """ return None - def set_xfer_handshake_metadata(self, metadata: dict[int, dict[int, dict]]) -> None: + def set_xfer_handshake_metadata(self, metadata: dict[int, KVConnectorHandshakeMetadata]) -> None: """ Set the KV connector handshake metadata for this connector. Args: - metadata (dict): the handshake metadata to set. + metadata (KVConnectorHandshakeMetadata): the handshake metadata to set. """ return None diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 598b56eef847..a54b2d5446ba 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -216,7 +216,7 @@ def request_finished( assert self.connector_scheduler is not None return self.connector_scheduler.request_finished(request, block_ids) - def set_xfer_handshake_metadata(self, metadata: dict[int, dict[int, dict]]) -> None: + def set_xfer_handshake_metadata(self, metadata: dict[int, KVConnectorHandshakeMetadata]) -> None: """ Set the KV connector handshake metadata for this connector. 
@@ -346,25 +346,27 @@ def shutdown(self): self._nixl_handshake_listener_t.join() self._nixl_handshake_listener_t = None - def set_xfer_handshake_metadata(self, metadata: dict[int, dict[int, dict]]) -> None: + def set_xfer_handshake_metadata(self, metadata: dict[int, KVConnectorHandshakeMetadata]) -> None: """ Set the KV connector handshake metadata for this connector. Args: metadata (dict): the handshake metadata to set. """ - encoded_data: dict[int, dict[int, bytes]] = {} - for dp_rank, tp_metadata in metadata.items(): - encoded_data[dp_rank] = {} - for tp_rank, rank_metadata in tp_metadata.items(): - encoder = msgspec.msgpack.Encoder() - encoded_data[dp_rank][tp_rank] = encoder.encode(rank_metadata) - logger.debug( - "Dp rank %d, Tp rank %d: encoded NixlAgentMetadata size: %s bytes", - dp_rank, - tp_rank, - str(len(encoded_data[dp_rank][tp_rank])), + encoded_data: dict[int, bytes] = {} + encoder = msgspec.msgpack.Encoder() + for tp_rank, rank_metadata in metadata.items(): + if not isinstance(rank_metadata, NixlAgentMetadata): + raise ValueError( + "NixlConnectorScheduler expects NixlAgentMetadata for " + "handshake metadata." ) + encoded_data[tp_rank] = encoder.encode(rank_metadata) + logger.debug( + "Tp rank %d: encoded NixlAgentMetadata size: %s bytes", + tp_rank, + str(len(encoded_data[tp_rank])), + ) self._encoded_xfer_handshake_metadata = encoded_data # Only start the listener when we have metadata to serve. @@ -386,7 +388,7 @@ def set_xfer_handshake_metadata(self, metadata: dict[int, dict[int, dict]]) -> N @staticmethod def _nixl_handshake_listener( - encoded_data: dict[int, dict[int, Any]], + encoded_data: dict[int, Any], ready_event: threading.Event, stop_event: threading.Event, port: int, @@ -410,16 +412,15 @@ def _nixl_handshake_listener( break continue # Decode the message which contains (GET_META_MSG, rank) - msg, target_dp_rank, target_tp_rank = msgspec.msgpack.decode(msg) + msg, target_tp_rank = msgspec.msgpack.decode(msg) logger.debug( - "Received message for dp rank %s tp rank %s", - target_dp_rank, + "Received message for tp rank %s", target_tp_rank, ) if msg != GET_META_MSG: logger.warning("Connection listener got unexpected message %s", msg) sock.send_multipart( - (identity, b"", encoded_data[target_dp_rank][target_tp_rank]) + (identity, b"", encoded_data[target_tp_rank]) ) def get_num_new_matched_tokens( @@ -880,16 +881,13 @@ def _nixl_handshake( p_remote_rank = self.kv_topo.get_target_remote_rank(remote_tp_size) path = make_zmq_path("tcp", host, port) logger.debug( - "Querying metadata on path: %s at remote tp rank %s", path, p_remote_tp_rank + "Querying metadata on path: %s at remote tp rank %s", path, p_remote_rank ) - # [WIP] gluo: using dp_rank 0 data (standard for disaggregated prefill-decode) - # is sufficient? - p_remote_dp_rank = 0 # Send query for the request. with zmq_ctx(zmq.REQ, path) as sock: msg = msgspec.msgpack.encode( - (GET_META_MSG, p_remote_dp_rank, p_remote_tp_rank) + (GET_META_MSG, p_remote_rank) ) # Set receive timeout to 5 seconds to avoid hanging on dead server sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds @@ -912,7 +910,7 @@ def _nixl_handshake( # Register Remote agent. remote_agent_name = self.add_remote_agent( - metadata, p_remote_tp_rank, remote_tp_size + metadata, p_remote_rank, remote_tp_size ) setup_agent_time = time.perf_counter() logger.debug( @@ -921,7 +919,7 @@ def _nixl_handshake( ) # Remote rank -> agent name. 
- return {p_remote_tp_rank: remote_agent_name} + return {p_remote_rank: remote_agent_name} def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None: """ diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py index de22331dc6e9..6a1838d3df74 100644 --- a/vllm/executor/uniproc_executor.py +++ b/vllm/executor/uniproc_executor.py @@ -87,10 +87,6 @@ def collective_rpc( future.set_exception(e) return [future] - def get_kv_connector_handshake_metadata(self) -> list[dict | None]: - """Get KV connector handshake metadata from all workers.""" - return self.collective_rpc("get_kv_connector_handshake_metadata") - def check_health(self) -> None: # UniProcExecutor will always be healthy as long as # it's running. diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index e4c32cb2b193..fd965a1b5f7d 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -182,16 +182,12 @@ def __init__( if xfer_handshake_metadata: # xfer_handshake_metadata is list of dicts from workers - # Each dict already has structure {dp_rank: {tp_rank: metadata}} + # Each dict already has structure {tp_rank: metadata} # Merge all worker dicts into a single dict - content: dict[int, dict[int, dict[int, Any]]] = {} + content: dict[int, Any] = {} for worker_dict in xfer_handshake_metadata: if worker_dict is not None: - # Deep merge nested dictionaries instead of overwrite - for dp_rank, tp_dict in worker_dict.items(): - if dp_rank not in content: - content[dp_rank] = {} - content[dp_rank].update(tp_dict) + content.update(worker_dict) kv_connector.set_xfer_handshake_metadata(content) # Setup batch queue for pipeline parallelism. diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 1eda2f01aa8a..2efae7a11350 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -18,6 +18,7 @@ from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput +from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorHandshakeMetadata FailureCallback = Callable[[], None] @@ -100,7 +101,7 @@ def collective_rpc( ) -> list[Any]: raise NotImplementedError - def get_kv_connector_handshake_metadata(self) -> list[dict[int, dict[int, dict]]]: + def get_kv_connector_handshake_metadata(self) -> list[dict[int, KVConnectorHandshakeMetadata]]: return self.collective_rpc("get_kv_connector_handshake_metadata") def execute_model( diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 255dd92188ea..107812bc9e36 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -332,17 +332,13 @@ def get_kv_connector_handshake_metadata(self) -> dict | None: return None connector = get_kv_transfer_group() + # Return None for connectors that don't need to exchange handshake + # metadata across workers. if (metadata := connector.get_handshake_metadata()) is None: - logger.warning( - "KV connector metadata is not available. " - "This may happen if the KV connector is not initialized " - "or the worker is not part of a disaggregated KV cache setup." 
- ) return None tp_rank = get_tp_group().rank_in_group - dp_rank = self.vllm_config.parallel_config.data_parallel_rank_local - return {dp_rank: {tp_rank: metadata}} + return {tp_rank: metadata} def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: return self.model_runner.get_kv_cache_spec() From c49b0bb5e98632871c432d9a86ffe85748b0d675 Mon Sep 17 00:00:00 2001 From: Guan Luo <41310872+GuanLuo@users.noreply.github.com> Date: Wed, 22 Oct 2025 12:56:10 +0800 Subject: [PATCH 11/19] style: style Signed-off-by: Guan Luo <41310872+GuanLuo@users.noreply.github.com> --- .../kv_transfer/kv_connector/v1/base.py | 4 +++- .../kv_connector/v1/nixl_connector.py | 19 +++++++++---------- vllm/v1/executor/abstract.py | 8 ++++++-- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index ab04b5499b01..8412638268b7 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -453,7 +453,9 @@ def build_kv_connector_stats( """ return None - def set_xfer_handshake_metadata(self, metadata: dict[int, KVConnectorHandshakeMetadata]) -> None: + def set_xfer_handshake_metadata( + self, metadata: dict[int, KVConnectorHandshakeMetadata] + ) -> None: """ Set the KV connector handshake metadata for this connector. diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index d90e6a3e45f7..817a40bf0f79 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -216,7 +216,9 @@ def request_finished( assert self.connector_scheduler is not None return self.connector_scheduler.request_finished(request, block_ids) - def set_xfer_handshake_metadata(self, metadata: dict[int, KVConnectorHandshakeMetadata]) -> None: + def set_xfer_handshake_metadata( + self, metadata: dict[int, KVConnectorHandshakeMetadata] + ) -> None: """ Set the KV connector handshake metadata for this connector. @@ -325,7 +327,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # Background thread for handling new handshake requests. self._nixl_handshake_listener_t: threading.Thread | None = None - self._encoded_xfer_handshake_metadata: dict[int, dict[int, Any]] = {} + self._encoded_xfer_handshake_metadata: dict[int, Any] = {} self._stop_event = threading.Event() # Requests that need to start recv/send. @@ -346,7 +348,9 @@ def shutdown(self): self._nixl_handshake_listener_t.join() self._nixl_handshake_listener_t = None - def set_xfer_handshake_metadata(self, metadata: dict[int, KVConnectorHandshakeMetadata]) -> None: + def set_xfer_handshake_metadata( + self, metadata: dict[int, KVConnectorHandshakeMetadata] + ) -> None: """ Set the KV connector handshake metadata for this connector. 
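For context on where these per-rank dicts end up: each worker returns `{tp_rank: metadata}` (gpu_worker.py above), the executor gathers one such dict per worker via collective_rpc, and the engine core folds them into a single mapping before calling `set_xfer_handshake_metadata`. A rough, self-contained sketch of that merge; `merge_worker_handshake_metadata` is a name made up for illustration, not a vLLM function:

    from typing import Any

    def merge_worker_handshake_metadata(
        worker_results: list[dict[int, Any] | None],
    ) -> dict[int, Any]:
        content: dict[int, Any] = {}
        for worker_dict in worker_results:
            if worker_dict is not None:
                # tp_ranks are unique across workers, so a flat update is enough;
                # the old deep merge over dp_rank is no longer needed.
                content.update(worker_dict)
        return content

    # e.g. TP=2: each worker contributes only its own rank's entry.
    merged = merge_worker_handshake_metadata([{0: "meta-tp0"}, {1: "meta-tp1"}])
    assert sorted(merged) == [0, 1]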
@@ -419,9 +423,7 @@ def _nixl_handshake_listener( ) if msg != GET_META_MSG: logger.warning("Connection listener got unexpected message %s", msg) - sock.send_multipart( - (identity, b"", encoded_data[target_tp_rank]) - ) + sock.send_multipart((identity, b"", encoded_data[target_tp_rank])) def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int @@ -870,7 +872,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): total_num_kv_heads=self.model_config.get_total_num_kv_heads(), ) - def _nixl_handshake( self, host: str, @@ -896,9 +897,7 @@ def _nixl_handshake( # Send query for the request. with zmq_ctx(zmq.REQ, path) as sock: - msg = msgspec.msgpack.encode( - (GET_META_MSG, p_remote_rank) - ) + msg = msgspec.msgpack.encode((GET_META_MSG, p_remote_rank)) # Set receive timeout to 5 seconds to avoid hanging on dead server sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds sock.send(msg) diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 5f6fb6b9a9e2..5b8adc7ffd97 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -9,6 +9,9 @@ from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorHandshakeMetadata, +) from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.tasks import SupportedTask @@ -17,7 +20,6 @@ from vllm.v1.engine import ReconfigureDistributedRequest from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput -from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorHandshakeMetadata from vllm.v1.worker.worker_base import WorkerBase logger = init_logger(__name__) @@ -175,7 +177,9 @@ def collective_rpc( ): raise NotImplementedError - def get_kv_connector_handshake_metadata(self) -> list[dict[int, KVConnectorHandshakeMetadata]]: + def get_kv_connector_handshake_metadata( + self, + ) -> list[dict[int, KVConnectorHandshakeMetadata]]: return self.collective_rpc("get_kv_connector_handshake_metadata") @overload From 7cc5a2d5fa2592dcfee78ef38aa4667357f9eae1 Mon Sep 17 00:00:00 2001 From: Guan Luo <41310872+GuanLuo@users.noreply.github.com> Date: Wed, 22 Oct 2025 19:05:07 +0800 Subject: [PATCH 12/19] fix: fix test Signed-off-by: Guan Luo <41310872+GuanLuo@users.noreply.github.com> --- .../kv_connector/unit/test_nixl_connector.py | 38 ++++++++----------- 1 file changed, 16 insertions(+), 22 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 3ea9ceda3755..93f2654c0446 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -345,29 +345,22 @@ def test_kv_transfer_handshake(dist_init): # Here we are testing the retrieval of NIXLAgentMetadata. # Knowing the implementation detail, we override the add_remote_agent # to validate the metadata received is the same as the one in prefill_connector. 
- received_metadata = None - - def mock_add_remote_agent( - agent_metadata: NixlAgentMetadata, - remote_tp_rank: int, - remote_tp_size: int, - ): - nonlocal received_metadata - received_metadata = (agent_metadata, remote_tp_rank, remote_tp_size) - return "remote_agent" - - decode_connector.connector_worker.add_remote_agent = mock_add_remote_agent + with patch.object( + decode_connector.connector_worker, "add_remote_agent" + ) as mock_add_remote_agent: + mock_add_remote_agent.return_type = "remote_agent" + + decode_connector.connector_worker._nixl_handshake( + kv_connector_metadata["remote_host"], + kv_connector_metadata["remote_port"], + kv_connector_metadata["tp_size"], + kv_connector_metadata["remote_engine_id"], + ) - decode_connector.connector_worker._nixl_handshake( - kv_connector_metadata["remote_host"], - kv_connector_metadata["remote_port"], - kv_connector_metadata["tp_size"], - kv_connector_metadata["remote_engine_id"], - ) - assert received_metadata is not None - assert received_metadata[1] == 0 # remote_tp_rank - assert received_metadata[2] == 1 # remote_tp_size - assert metadata[0][0] == received_metadata[0] + received_metadata = mock_add_remote_agent.call_args.args + assert received_metadata[1] == 0 # remote_tp_rank + assert received_metadata[2] == 1 # remote_tp_size + assert metadata[0] == received_metadata[0] # Need to shutdown the background thread to release NIXL side channel port scheduler_connector.shutdown() @@ -703,6 +696,7 @@ def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental( engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID, agent_metadata=FakeNixlWrapper.AGENT_METADATA, kv_caches_base_addr=[0], + device_id=[0], num_blocks=1, # prefill TP=1, decode TP=2, remote block_lens is double to local block_lens=[i * 2 for i in worker.block_len_per_layer], From 84c6130a948fa4f835ac152d1cb62e7f564c75cc Mon Sep 17 00:00:00 2001 From: Guan Luo <41310872+GuanLuo@users.noreply.github.com> Date: Fri, 24 Oct 2025 10:42:20 +0800 Subject: [PATCH 13/19] chore: address comment Signed-off-by: Guan Luo <41310872+GuanLuo@users.noreply.github.com> --- docs/features/nixl_connector_usage.md | 2 +- .../kv_connector/unit/test_nixl_connector.py | 6 ++-- .../kv_connector/v1/nixl_connector.py | 32 +++++++------------ 3 files changed, 16 insertions(+), 24 deletions(-) diff --git a/docs/features/nixl_connector_usage.md b/docs/features/nixl_connector_usage.md index 605398652ee0..1ce038f4d652 100644 --- a/docs/features/nixl_connector_usage.md +++ b/docs/features/nixl_connector_usage.md @@ -81,7 +81,7 @@ python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \ - Default: 5600 - **Required for both prefiller and decoder instances** - Each vLLM worker needs a unique port on its host; using the same port number across different hosts is fine - - For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank * tp_size + tp_rank (e.g., with `--tensor-parallel-size=4` and base_port=5600, tp_rank 0..3 use ports 5600, 5601, 5602, 5603 on that node). + - For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank (e.g., with `--data-parallel-size=2` and base_port=5600, dp_rank 0..1 use port 5600, 5601 on that node). 
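A tiny sketch of the revised port layout the doc bullet above describes: one side-channel listener per DP rank, shared by that rank's TP workers, rather than one port per TP worker. `side_channel_port` is a throwaway helper name for the example, not vLLM API:

    def side_channel_port(base_port: int, dp_rank: int) -> int:
        return base_port + dp_rank

    # --data-parallel-size=2 with base_port=5600 -> ports 5600 and 5601.
    assert [side_channel_port(5600, dp) for dp in range(2)] == [5600, 5601]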
- Used for the initial NIXL handshake between the prefiller and the decoder - `VLLM_NIXL_SIDE_CHANNEL_HOST`: Host for side channel communication diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 93f2654c0446..e98ec53d4816 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -396,7 +396,7 @@ def _nixl_handshake( engine_id=self.REMOTE_ENGINE_ID, agent_metadata=FakeNixlWrapper.AGENT_METADATA, kv_caches_base_addr=[0], - device_id=[0], + device_id=0, num_blocks=1, block_lens=self.block_len_per_layer, attn_backend_name=self.backend_name, @@ -643,7 +643,7 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init): engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID, agent_metadata=FakeNixlWrapper.AGENT_METADATA, kv_caches_base_addr=[0], - device_id=[0], + device_id=0, num_blocks=1, block_lens=worker.block_len_per_layer, attn_backend_name=worker.backend_name, @@ -696,7 +696,7 @@ def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental( engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID, agent_metadata=FakeNixlWrapper.AGENT_METADATA, kv_caches_base_addr=[0], - device_id=[0], + device_id=0, num_blocks=1, # prefill TP=1, decode TP=2, remote block_lens is double to local block_lens=[i * 2 for i in worker.block_len_per_layer], diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 817a40bf0f79..f9271c828559 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -94,7 +94,7 @@ class NixlAgentMetadata(KVConnectorHandshakeMetadata): engine_id: str agent_metadata: bytes kv_caches_base_addr: list[int] - device_id: list[int] + device_id: int num_blocks: int block_lens: list[int] attn_backend_name: str @@ -789,7 +789,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # Map of engine_id -> kv_caches_base_addr. For TP case, each local # rank will still only pull from a single remote TP worker. self.kv_caches_base_addr: dict[EngineId, list[int]] = {} - self.device_id: dict[EngineId, list[int]] = {} + self.device_id: dict[EngineId, int] = {} # Number of NIXL regions. Currently one region per cache # (so 1 per layer for MLA, otherwise 2 per layer) @@ -1027,8 +1027,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): caches_data = [] # With hybrid allocator, layers can share a kv cache tensor seen_base_addresses = [] - # Map from address to device ID - seen_addresses_device_id = [] + device_id = 0 # Note(tms): I modified this from the original region setup code. # K and V are now in different regions. Advantage is that we can @@ -1052,10 +1051,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): continue seen_base_addresses.append(base_addr) - # Need to make sure the device ID is non-negative for NIXL, - # Torch uses -1 to indicate CPU tensors while NIXL uses explicit - # memory type. - seen_addresses_device_id.append(max(cache.get_device(), 0)) curr_tensor_size_bytes = cache.numel() * cache.element_size() if tensor_size_bytes is None: @@ -1081,9 +1076,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # Need to make sure the device ID is non-negative for NIXL, # Torch uses -1 to indicate CPU tensors while NIXL uses explicit # memory type. 
- caches_data.append( - (base_addr, curr_tensor_size_bytes, max(cache.get_device(), 0), "") - ) + device_id = max(cache.get_device(), 0) + caches_data.append((base_addr, curr_tensor_size_bytes, device_id, "")) logger.debug( "Different block lengths collected: %s", set(self.block_len_per_layer) @@ -1092,7 +1086,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): assert self.num_blocks != 0 self.kv_caches_base_addr[self.engine_id] = seen_base_addresses - self.device_id[self.engine_id] = seen_addresses_device_id + self.device_id[self.engine_id] = device_id self.num_regions = len(caches_data) self.num_layers = len(xfer_buffers.keys()) @@ -1119,9 +1113,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # Register local/src descr for NIXL xfer. blocks_data = [] - for i, (base_addr, device_id) in enumerate( - zip(seen_base_addresses, seen_addresses_device_id) - ): + for i, base_addr in enumerate(seen_base_addresses): kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) # NOTE With heter-TP, more blocks are prepared than what are # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We @@ -1280,9 +1272,7 @@ def add_remote_agent( # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. # Register all remote blocks, but only the corresponding kv heads. - for i, (base_addr, device_id) in enumerate( - zip(nixl_agent_meta.kv_caches_base_addr, nixl_agent_meta.device_id) - ): + for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) rank_offset = ( self.tp_rank % tp_ratio * kv_block_len if not replicates_kv_cache else 0 @@ -1294,7 +1284,7 @@ def add_remote_agent( # self.block_len == remote_block_len//tp_ratio bytes. addr = base_addr + block_offset + rank_offset # (addr, len, device id) - blocks_data.append((addr, kv_block_len, device_id)) + blocks_data.append((addr, kv_block_len, nixl_agent_meta.device_id)) if self._use_flashinfer: # With FlashInfer index V separately to allow head splitting. 
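A small sketch of the device-id normalization used in the hunks above, under the assumption stated in the diff comments: torch reports -1 for CPU tensors while NIXL expects a non-negative device id (the memory type is conveyed separately). `nixl_device_id` is an illustrative helper, not part of the connector:

    import torch

    def nixl_device_id(cache: torch.Tensor) -> int:
        # torch.Tensor.get_device() returns -1 for CPU tensors.
        return max(cache.get_device(), 0)

    assert nixl_device_id(torch.zeros(1)) == 0  # CPU tensor clamps to 0
    # A GPU tensor keeps its ordinal, e.g. torch.zeros(1, device="cuda:1") -> 1.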
@@ -1302,7 +1292,9 @@ def add_remote_agent( block_offset = block_id * nixl_agent_meta.block_lens[i] addr = base_addr + block_offset + rank_offset v_addr = addr + nixl_agent_meta.block_lens[i] // 2 - blocks_data.append((v_addr, kv_block_len, device_id)) + blocks_data.append( + (v_addr, kv_block_len, nixl_agent_meta.device_id) + ) logger.debug( "Created %s blocks for dst engine %s with remote rank %s and local rank %s", From 70095536c843052699804d595adafcb0c283acbf Mon Sep 17 00:00:00 2001 From: Guan Luo <41310872+GuanLuo@users.noreply.github.com> Date: Fri, 24 Oct 2025 11:37:04 +0800 Subject: [PATCH 14/19] fix: use mock nixl lib Signed-off-by: Guan Luo <41310872+GuanLuo@users.noreply.github.com> --- tests/v1/kv_connector/unit/test_nixl_connector.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 6e68e7f09ea6..466884d88f22 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -284,6 +284,10 @@ def test_prompt_less_than_block_size(): assert len(scheduler_output.scheduled_new_reqs) == 0 +@patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper, +) def test_kv_transfer_handshake(dist_init): """Unit test for basic NixlConnector interface functionality.""" From c72854544629a28389ef0bc87b904907ee47c92f Mon Sep 17 00:00:00 2001 From: GuanLuo <41310872+GuanLuo@users.noreply.github.com> Date: Wed, 29 Oct 2025 16:26:13 +0800 Subject: [PATCH 15/19] remove addition newline Signed-off-by: GuanLuo <41310872+GuanLuo@users.noreply.github.com> --- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 05693ffd8fc8..7f21019b2f26 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1732,7 +1732,6 @@ def _read_blocks( # Get descs ids. 
local_block_descs_ids: np.ndarray - remote_block_descs_ids: np.ndarray if not self.block_window_per_layer: # Default case: assume global attention From ee1d4afaa80b588a8f111c9f7727e82dcb6bd741 Mon Sep 17 00:00:00 2001 From: Guan Luo <41310872+GuanLuo@users.noreply.github.com> Date: Thu, 30 Oct 2025 20:48:37 +0800 Subject: [PATCH 16/19] fix: fix pre_commit Signed-off-by: Guan Luo <41310872+GuanLuo@users.noreply.github.com> --- vllm/v1/core/sched/scheduler.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index ad6fbee2ec08..a8109ff20e79 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -93,7 +93,6 @@ def __init__( ) connector_vllm_config = copy.copy(self.vllm_config) - connector_vllm_config.kv_cache_config = copy.copy(kv_cache_config) self.connector = KVConnectorFactory.create_connector( config=connector_vllm_config, role=KVConnectorRole.SCHEDULER ) @@ -1335,7 +1334,7 @@ def _connector_finished( assert len(self.kv_cache_config.kv_cache_groups) == 1 return self.connector.request_finished(request, block_ids[0]) else: - return self.connector.request_finished(request, block_ids) + return self.connector.request_finished_all_groups(request, block_ids) def _update_waiting_for_remote_kv(self, request: Request) -> bool: """ From 552b0fb91189c5b575b97e519f4e13d856158aea Mon Sep 17 00:00:00 2001 From: Guan Luo <41310872+GuanLuo@users.noreply.github.com> Date: Thu, 30 Oct 2025 21:02:52 +0800 Subject: [PATCH 17/19] fix: fix up Signed-off-by: Guan Luo <41310872+GuanLuo@users.noreply.github.com> --- vllm/v1/core/sched/scheduler.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index a8109ff20e79..3ca0407e6426 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -1334,7 +1334,13 @@ def _connector_finished( assert len(self.kv_cache_config.kv_cache_groups) == 1 return self.connector.request_finished(request, block_ids[0]) else: - return self.connector.request_finished_all_groups(request, block_ids) + # NOTE(gluo): Ignoring mypy error as request_finished_all_groups() is + # not defined in KVConnectorBase_V1. However, supports_hma() implies + # the connector implements SupportsHMA interface which provides this + # method. + # This ignore is only for unblocking mypy check, should not be needed + # after we merge the code path as mentioned in the above comment. 
+ return self.connector.request_finished_all_groups(request, block_ids) # type: ignore[attr-defined] def _update_waiting_for_remote_kv(self, request: Request) -> bool: """ From 696dd7b236ad141b64859b7f30cb14054c450166 Mon Sep 17 00:00:00 2001 From: Guan Luo <41310872+GuanLuo@users.noreply.github.com> Date: Thu, 30 Oct 2025 21:21:46 +0800 Subject: [PATCH 18/19] fix: fix up Signed-off-by: Guan Luo <41310872+GuanLuo@users.noreply.github.com> --- vllm/model_executor/models/config.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index b0a48a9f1d45..2c8dd6eacbe5 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -334,7 +334,8 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: return # Save the user input before it gets modified by MambaModelConfig - mamba_block_size = vllm_config.cache_config.mamba_block_size + # NOTE(gluo): F841 is suppressed as mamba_block_size is used but ruff flags it + mamba_block_size = vllm_config.cache_config.mamba_block_size # noqa: F841 # Enable FULL_AND_PIECEWISE by default MambaModelConfig.verify_and_update_config(vllm_config) From 99c039a48a2e13295f367ab472c32bd5b53171aa Mon Sep 17 00:00:00 2001 From: Guan Luo <41310872+GuanLuo@users.noreply.github.com> Date: Fri, 31 Oct 2025 13:05:06 +0800 Subject: [PATCH 19/19] chore: undo unrelated change Signed-off-by: Guan Luo <41310872+GuanLuo@users.noreply.github.com> --- vllm/model_executor/models/config.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 9dcb4681a672..7150977e9266 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -334,8 +334,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: return # Save the user input before it gets modified by MambaModelConfig - # NOTE(gluo): F841 is suppressed as mamba_block_size is used but ruff flags it - mamba_block_size = vllm_config.cache_config.mamba_block_size # noqa: F841 + mamba_block_size = vllm_config.cache_config.mamba_block_size # Enable FULL_AND_PIECEWISE by default MambaModelConfig.verify_and_update_config(vllm_config)
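Stepping back to the handshake path this series reworks: the decode worker now sends (GET_META_MSG, remote_tp_rank) over a ZMQ REQ socket with a 5-second receive timeout and msgpack-decodes the reply served by the scheduler-side listener. A condensed, illustrative sketch of that client side only; the GET_META_MSG value and the `query_handshake_metadata` helper are assumptions made for the example, not the vLLM code path:

    import msgspec
    import zmq

    GET_META_MSG = b"get_meta_msg"  # placeholder value for the example

    def query_handshake_metadata(path: str, remote_tp_rank: int) -> dict:
        ctx = zmq.Context.instance()
        sock = ctx.socket(zmq.REQ)
        try:
            # Avoid hanging forever on a dead server, mirroring the 5000 ms
            # RCVTIMEO set in the connector.
            sock.setsockopt(zmq.RCVTIMEO, 5000)
            sock.connect(path)
            sock.send(msgspec.msgpack.encode((GET_META_MSG, remote_tp_rank)))
            # The listener replies with the msgpack blob for the requested rank;
            # decoding without a target type yields a plain dict of the fields.
            return msgspec.msgpack.decode(sock.recv())
        finally:
            sock.close(linger=0)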