22 changes: 22 additions & 0 deletions vllm/distributed/kv_transfer/kv_connector/v1/base.py
@@ -78,6 +78,15 @@ class KVConnectorRole(enum.Enum):
WORKER = 1


class KVConnectorHandshakeMetadata(ABC): # noqa: B024
"""
Metadata used for out-of-band connector handshake between
P/D workers. This needs to be serializable.
"""

pass

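A concrete connector subclasses this marker with whatever fields its remote peer needs; the only contract is that instances serialize cleanly. A minimal sketch, assuming a dataclass-based connector (the class and field names below are illustrative, not part of this PR):

```python
from dataclasses import asdict, dataclass

from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorHandshakeMetadata,
)


@dataclass
class ExampleHandshakeMetadata(KVConnectorHandshakeMetadata):
    # Hypothetical fields a connector might advertise to its peer.
    engine_id: str
    side_channel_host: str
    side_channel_port: int

    def to_dict(self) -> dict:
        # A plain dict keeps the payload serializable (json/msgpack),
        # per the base class contract.
        return asdict(self)
```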

class KVConnectorMetadata(ABC): # noqa: B024
"""
Abstract Metadata used to communicate between the
@@ -272,6 +281,19 @@ def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]:
"""
return None

@property
def handshake_metadata(self) -> Optional[KVConnectorHandshakeMetadata]:
"""
Get the KVConnector handshake metadata for this connector.
This metadata is used for out-of-band connector handshake
between P/D workers.

Returns:
KVConnectorHandshakeMetadata: the handshake metadata.
None if no handshake metadata is available.
"""
return None

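Callers must treat this property as optional: the base implementation returns None, and scheduler-side connectors never hold worker state. A hedged usage sketch (the `connector` variable and the msgpack wire format are assumptions for illustration):

```python
import msgspec

# `connector` is any KVConnectorBase_V1 implementation, assumed in scope.
meta = connector.handshake_metadata
if meta is not None:
    # The base class only promises serializability; msgspec's msgpack
    # encoder is one reasonable wire format for the side channel.
    payload = msgspec.msgpack.encode(meta)
    # ... publish `payload` over the out-of-band channel ...
```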
# ==============================
# Scheduler-side methods
# ==============================
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import copy
import logging
import math
@@ -10,15 +9,12 @@
import time
import uuid
from collections import defaultdict
from collections.abc import Iterator
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

import msgspec
import numpy as np
import torch
import zmq

from vllm import envs
from vllm.attention.backends.registry import _Backend, backend_name_to_enum
@@ -31,6 +27,18 @@
KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats

# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector._nixl_import import ( # noqa: E501
NixlWrapper,
nixl_agent_config,
nixlXferTelemetry,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.handshake import ( # noqa: E501
HandshakeStrategy,
NixlAgentMetadata,
ZmqHandshakeStrategy,
)
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
@@ -40,7 +48,6 @@
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import make_zmq_path, make_zmq_socket
from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput

@@ -57,24 +64,6 @@

logger = init_logger(__name__)

# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
try:
from nixl._api import nixl_agent as NixlWrapper
from nixl._bindings import nixlXferTelemetry

logger.info("NIXL is available")
except ImportError:
logger.warning("NIXL is not available")
NixlWrapper = None
nixlXferTelemetry = None


try:
from nixl._api import nixl_agent_config
except ImportError:
nixl_agent_config = None
logger.warning("NIXL agent config is not available")

# Supported platforms and types of kv transfer buffer.
# {device: tuple of supported kv buffer types}
_NIXL_SUPPORTED_DEVICE = {
@@ -89,21 +78,6 @@
_NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices())


class NixlAgentMetadata(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
# required for @cached_property.
dict=True,
):
engine_id: str
agent_metadata: bytes
kv_caches_base_addr: list[int]
num_blocks: int
block_lens: list[int]
attn_backend_name: str
kv_cache_layout: str


@dataclass
class ReqMeta:
local_block_ids: list[int]
@@ -244,6 +218,13 @@ def get_kv_connector_stats(self) -> KVConnectorStats | None:
assert self.connector_worker is not None
return self.connector_worker.get_kv_connector_stats()

@property
def handshake_metadata(self):
"""Get current handshake metadata from worker process."""
if self.connector_worker is None:
return None
return self.connector_worker.xfer_metadata

@classmethod
def build_kv_connector_stats(
cls, data: dict[str, Any] | None = None
@@ -625,8 +606,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
# requests that skipped transfer (handshake or transfer failures)
self._failed_recv_reqs: set[ReqId] = set()

# Background thread for handling new handshake requests.
self._nixl_handshake_listener_t: threading.Thread | None = None
# Background thread for initializing new NIXL handshakes.
self._handshake_initiation_executor = ThreadPoolExecutor(
# NIXL is not guaranteed to be thread-safe, limit 1 worker.
@@ -670,33 +649,15 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
self.xfer_stats = NixlKVConnectorStats()

@staticmethod
def _nixl_handshake_listener(
metadata: NixlAgentMetadata,
ready_event: threading.Event,
base_port: int,
tp_rank: int,
):
"""Background thread for getting new NIXL handshakes."""
# NOTE(rob): this is a simple implementation. We will move
# to a better approach via HTTP endpoint soon.

encoder = msgspec.msgpack.Encoder()
encoded_data = encoder.encode(metadata)
size_in_bytes = len(encoded_data)
logger.debug("Size of encoded NixlAgentMetadata: %s bytes", str(size_in_bytes))

# Listen for new requests for metadata.
host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
path = make_zmq_path("tcp", host, base_port + tp_rank)
logger.debug("Starting listening on path: %s", path)
with zmq_ctx(zmq.ROUTER, path) as sock:
ready_event.set()
while True:
identity, _, msg = sock.recv_multipart()
if msg != GET_META_MSG:
logger.warning("Connection listener got unexpected message %s", msg)
sock.send_multipart((identity, b"", encoded_data))
self._handshake_strategy: HandshakeStrategy = ZmqHandshakeStrategy(
self.nixl_wrapper,
self.tp_rank,
self.world_size,
self.side_channel_port,
self.engine_id,
self.add_remote_agent,
self._handshake_lock,
)

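The diff exercises the strategy in three places: `initiate_handshake(...)` below, `setup_listener(...)` in `register_kv_caches`, and `cleanup()` in `shutdown`. Since `handshake.py` is not part of this excerpt, the interface below is inferred from those call sites, not the actual definition:

```python
from abc import ABC, abstractmethod


class HandshakeStrategy(ABC):
    """Sketch inferred from call sites in this diff; see handshake.py."""

    @abstractmethod
    def initiate_handshake(
        self, host: str, port: int, remote_tp_size: int, expected_engine_id: str
    ) -> dict[int, str]:
        """Handshake with a remote instance; maps remote TP rank -> agent name."""

    @abstractmethod
    def setup_listener(self, metadata: "NixlAgentMetadata") -> None:
        """Serve local agent metadata to peers once KV caches are registered."""

    @abstractmethod
    def cleanup(self) -> None:
        """Tear down listener threads/sockets on shutdown."""
```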
def _nixl_handshake(
self,
@@ -706,56 +667,10 @@ def _nixl_handshake(
expected_engine_id: str,
) -> dict[int, str]:
"""Do a NIXL handshake with a remote instance."""

start_time = time.perf_counter()

# NOTE(rob): we need each rank to have a unique port. This is
# a hack to keep us moving. We will switch when moving to etcd
# or where we have a single ZMQ socket in the scheduler.

# Handshake only with the remote TP rank that current local rank will
# pull from. With homogeneous TP it happens to be the same rank_i.
tp_ratio = self._tp_size[self.engine_id] // remote_tp_size
p_remote_rank = self.tp_rank // tp_ratio
path = make_zmq_path("tcp", host, port + p_remote_rank)
logger.debug(
"Querying metadata on path: %s at remote rank %s", path, p_remote_rank
return self._handshake_strategy.initiate_handshake(
host, port, remote_tp_size, expected_engine_id
)

# Send query for the request.
with zmq_ctx(zmq.REQ, path) as sock:
# Set receive timeout to 5 seconds to avoid hanging on dead server
sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds
sock.send(GET_META_MSG)
metadata_bytes = sock.recv()
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
metadata = decoder.decode(metadata_bytes)
got_metadata_time = time.perf_counter()
logger.debug(
"NIXL handshake: get metadata took: %s", got_metadata_time - start_time
)

# Ensure engine id matches.
if metadata.engine_id != expected_engine_id:
raise RuntimeError(
f"Remote NIXL agent engine ID mismatch. "
f"Expected {expected_engine_id},"
f"received {metadata.engine_id}."
)

# Register Remote agent.
remote_agent_name = self.add_remote_agent(
metadata, p_remote_rank, remote_tp_size
)
setup_agent_time = time.perf_counter()
logger.debug(
"NIXL handshake: add agent took: %s",
setup_agent_time - got_metadata_time,
)

# Remote rank -> agent name.
return {p_remote_rank: remote_agent_name}

def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None:
"""
Initialize transfer buffer in CPU mem for accelerators
@@ -996,7 +911,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
assert len(self.block_window_per_layer) == self.num_layers

# After KV Caches registered, listen for new connections.
metadata = NixlAgentMetadata(
self.xfer_metadata = NixlAgentMetadata(
engine_id=self.engine_id,
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
@@ -1005,15 +920,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
attn_backend_name=self.backend_name,
kv_cache_layout=self.kv_cache_layout,
)
ready_event = threading.Event()
self._nixl_handshake_listener_t = threading.Thread(
target=self._nixl_handshake_listener,
args=(metadata, ready_event, self.side_channel_port, self.tp_rank),
daemon=True,
name="nixl_handshake_listener",
)
self._nixl_handshake_listener_t.start()
ready_event.wait() # Wait for listener ZMQ socket to be ready.
self._handshake_strategy.setup_listener(self.xfer_metadata)

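The stored `xfer_metadata` is the same msgspec struct the side channel serves, so it round-trips through msgpack exactly as the removed listener code did with `msgspec.msgpack.Encoder`/`Decoder`. A runnable sketch of that cycle (all field values are dummies):

```python
import msgspec

from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.handshake import (
    NixlAgentMetadata,
)

meta = NixlAgentMetadata(
    engine_id="engine-0",
    agent_metadata=b"opaque-nixl-agent-blob",  # dummy bytes
    kv_caches_base_addr=[0x7F0000000000],
    num_blocks=1024,
    block_lens=[16],
    attn_backend_name="FLASH_ATTN",
    kv_cache_layout="HND",
)

# Same wire format the ZMQ side channel uses on both ends.
payload = msgspec.msgpack.encode(meta)
decoded = msgspec.msgpack.Decoder(NixlAgentMetadata).decode(payload)
assert decoded.engine_id == meta.engine_id
```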
def add_remote_agent(
self,
@@ -1605,9 +1512,7 @@ def get_block_ids_with_load_errors(self) -> set[int]:
def shutdown(self):
"""Shutdown the connector worker."""
self._handshake_initiation_executor.shutdown(wait=False)
if self._nixl_handshake_listener_t is not None:
self._nixl_handshake_listener_t.join(timeout=0)
self._nixl_handshake_listener_t = None
self._handshake_strategy.cleanup()
for handles in self._recving_transfers.values():
for handle, _ in handles:
self.nixl_wrapper.release_xfer_handle(handle)
@@ -1627,24 +1532,6 @@ def shutdown(self):
self._registered_descs.clear()


@contextlib.contextmanager
def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
"""Context manager for a ZMQ socket"""

if socket_type not in (zmq.ROUTER, zmq.REQ):
raise ValueError(f"Unexpected socket type: {socket_type}")

ctx: zmq.Context | None = None
try:
ctx = zmq.Context() # type: ignore[attr-defined]
yield make_zmq_socket(
ctx=ctx, path=addr, socket_type=socket_type, bind=socket_type == zmq.ROUTER
)
finally:
if ctx is not None:
ctx.destroy(linger=0)


@dataclass
class NixlKVConnectorStats(KVConnectorStats):
"""Container for transfer performance metrics"""
26 changes: 26 additions & 0 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector/_nixl_import.py
@@ -0,0 +1,26 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Centralized lazy import for NIXL wrapper to avoid circular dependencies."""

from vllm.logger import init_logger

logger = init_logger(__name__)

# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
try:
from nixl._api import nixl_agent as NixlWrapper
from nixl._bindings import nixlXferTelemetry

logger.info("NIXL is available")
except ImportError:
logger.warning("NIXL is not available")
NixlWrapper = None
nixlXferTelemetry = None

try:
from nixl._api import nixl_agent_config
except ImportError:
nixl_agent_config = None
logger.warning("NIXL agent config is not available")

__all__ = ["NixlWrapper", "nixlXferTelemetry", "nixl_agent_config"]
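Because the fallbacks are None placeholders rather than import-time errors, consumers must guard before constructing an agent. A minimal sketch of the expected call-site pattern (the helper name and error text are illustrative):

```python
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector._nixl_import import (
    NixlWrapper,
)


def require_nixl() -> None:
    """Fail fast with a clear error when the optional nixl package is absent."""
    if NixlWrapper is None:
        raise RuntimeError(
            "NIXL is not installed; install the `nixl` package to use "
            "the NIXL KV connector."
        )
```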