Skip to content
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 104 additions & 2 deletions tests/v1/kv_connector/unit/test_nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
NixlAgentMetadata,
NixlConnector,
NixlConnectorMetadata,
NixlConnectorScheduler,
NixlConnectorWorker,
NixlKVConnectorStats,
)
Expand Down Expand Up @@ -283,6 +284,95 @@ def test_prompt_less_than_block_size():
assert len(scheduler_output.scheduled_new_reqs) == 0


def test_kv_transfer_metadata(dist_init):
    """Unit test for basic NixlConnector interface functionality.

    Flow exercised by this test:
      1. Build a scheduler whose KV connector has role SCHEDULER.
      2. Register KV caches on a WORKER-role "prefill" connector and hand
         its handshake metadata to the scheduler-side connector.
      3. Finish a request with ``do_remote_decode=True`` and collect the
         transfer parameters returned by ``request_finished``.
      4. Run the handshake from a second WORKER-role "decode" connector and
         check that the metadata it receives equals the prefill metadata.

    Args:
        dist_init: pytest fixture defined outside this chunk; it is not
            referenced in the body (presumably it sets up the distributed
            environment -- TODO confirm).
    """

    # Test setup: we create a scheduler that contains a NixlConnector
    # of role SCHEDULER, and expect it to be serving NixlAgentMetadata from
    # all workers of the instance.
    vllm_config = create_vllm_config()
    # in case the test runs on non-GPU machine
    vllm_config.kv_transfer_config.kv_buffer_device = "cpu"
    scheduler = create_scheduler(vllm_config)

    # Create two NixlConnectors of role WORKER, one is the worker of
    # the scheduler (prefill), the other is a worker of decode instance.

    # Prefill connector will register KV cache to populate proper handshake
    # metadata.
    prefill_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
    kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
        num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
    )
    shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
    unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
    # "layer0" and "layer2" deliberately reference the same tensor object.
    kv_caches = {
        "layer0": shared_tensor,
        "layer1": unique_tensor,
        "layer2": shared_tensor,
    }
    prefill_connector.register_kv_caches(kv_caches)

    # Simulate EngineCore initialization that would
    # gather connector metadata from all workers, the scheduler connector
    # expects metadata to be in Dict[int, Dict[int, KVConnectorHandshakeMetadata]],
    # where the first key is the dp_rank, the second key is the tp_rank.
    metadata = {0: {0: prefill_connector.get_handshake_metadata()}}
    scheduler_connector = scheduler.get_kv_connector()

    # Everything from here on runs under try/finally so that the scheduler
    # connector is always shut down, even when an assertion fails. Otherwise
    # the background thread would keep the NIXL side channel port occupied
    # and break subsequent tests.
    try:
        scheduler_connector.set_xfer_handshake_metadata(metadata)

        # Simulate a request that finishes prefill, which returns
        # corresponding NixlConnectorMetadata for decode instance.
        BLOCK_SIZE = vllm_config.cache_config.block_size
        NUM_EXTERNAL_FULL_BLOCKS = 2
        # Two full blocks plus half a block of tokens.
        NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

        request = create_request(
            request_id=1,
            block_size=BLOCK_SIZE,
            num_tokens=NUM_TOKENS,
            do_remote_decode=True,
        )
        request.status = RequestStatus.FINISHED_LENGTH_CAPPED
        delay, kv_connector_metadata = scheduler.get_kv_connector().request_finished(
            request, [0, 1, 2]
        )
        assert delay

        # Decode connector will be able to create handshake with the
        # prefill connector.
        decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)

        # Here we are testing the retrieval of NixlAgentMetadata.
        # Knowing the implementation detail, we override the add_remote_agent
        # to validate the metadata received is the same as the one in
        # prefill_connector.
        received_metadata = None

        def mock_add_remote_agent(
            agent_metadata: NixlAgentMetadata,
            remote_tp_rank: int,
            remote_tp_size: int,
        ):
            # Capture the arguments instead of registering a real agent.
            nonlocal received_metadata
            received_metadata = (agent_metadata, remote_tp_rank, remote_tp_size)
            return "remote_agent"

        decode_connector.connector_worker.add_remote_agent = mock_add_remote_agent

        decode_connector.connector_worker._nixl_handshake(
            kv_connector_metadata["remote_host"],
            kv_connector_metadata["remote_port"],
            kv_connector_metadata["tp_size"],
            kv_connector_metadata["remote_engine_id"],
        )
        assert received_metadata is not None
        assert received_metadata[1] == 0  # remote_tp_rank
        assert received_metadata[2] == 1  # remote_tp_size
        assert metadata[0][0] == received_metadata[0]
    finally:
        # Need to shutdown the background thread to release NIXL side
        # channel port
        scheduler_connector.shutdown()


class FakeNixlConnectorWorker(NixlConnectorWorker):
REMOTE_ENGINE_ID = "remote_engine"

Expand Down Expand Up @@ -310,6 +400,7 @@ def _nixl_handshake(
engine_id=self.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
device_id=[0],
num_blocks=1,
block_lens=self.block_len_per_layer,
attn_backend_name=self.backend_name,
Expand Down Expand Up @@ -556,6 +647,7 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init):
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
device_id=[0],
num_blocks=1,
block_lens=worker.block_len_per_layer,
attn_backend_name=worker.backend_name,
Expand Down Expand Up @@ -873,6 +965,8 @@ def _run_abort_timeout_test(llm_kwargs: dict, timeout: int):
_ = llm.generate([f"What is the capital of France? {padding}"], sampling_params)
# Request-0 times out and is cleared!
assert "0" not in req_to_blocks
# Need to shutdown the background thread to release NIXL side channel port
llm.llm_engine.engine_core.shutdown()


def test_register_kv_caches(dist_init):
Expand Down Expand Up @@ -1042,12 +1136,15 @@ def test_shutdown_cleans_up_resources(dist_init):
"""Test that shutdown() properly cleans up all resources."""
vllm_config = create_vllm_config()

scheduler = NixlConnectorScheduler(
vllm_config, vllm_config.kv_transfer_config.engine_id
)
worker = NixlConnectorWorker(vllm_config, vllm_config.kv_transfer_config.engine_id)
nixl_wrapper = worker.nixl_wrapper

with (
patch.object(worker, "_handshake_initiation_executor") as mock_exec,
patch.object(worker, "_nixl_handshake_listener_t") as mock_listener,
patch.object(scheduler, "_nixl_handshake_listener_t") as mock_listener,
patch.object(nixl_wrapper, "release_xfer_handle") as mock_rel_xfer,
patch.object(nixl_wrapper, "release_dlist_handle") as mock_rel_dlist,
patch.object(nixl_wrapper, "remove_remote_agent") as mock_rem_agent,
Expand All @@ -1066,7 +1163,12 @@ def test_shutdown_cleans_up_resources(dist_init):
worker.shutdown()

mock_exec.shutdown.assert_called_with(wait=False)
mock_listener.join.assert_called_once_with(timeout=0)

# Same sequence on scheduler.shutdown()
scheduler.shutdown()
scheduler.shutdown()
scheduler.shutdown()
mock_listener.join.assert_called_once()

mock_rel_xfer.assert_called_once_with(123)
assert mock_rel_dlist.call_count == 2
Expand Down
30 changes: 30 additions & 0 deletions vllm/distributed/kv_transfer/kv_connector/v1/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,15 @@ class KVConnectorRole(enum.Enum):
WORKER = 1


class KVConnectorHandshakeMetadata(ABC):  # noqa: B024
    """
    Metadata used for out-of-band connector handshake between
    P/D workers. This needs to be serializable.

    The class declares no members of its own; the B024 suppression
    acknowledges that this ABC intentionally has no abstract methods.
    """

    pass


class KVConnectorMetadata(ABC): # noqa: B024
"""
Abstract Metadata used to communicate between the
Expand Down Expand Up @@ -276,6 +285,18 @@ def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]:
"""
return None

def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
    """
    Get the KVConnector handshake metadata for this connector.
    This metadata is used for out-of-band connector handshake
    between P/D workers.

    This default implementation always returns None.

    Returns:
        KVConnectorHandshakeMetadata: the handshake metadata.
        None if no handshake metadata is available.
    """
    return None

# ==============================
# Scheduler-side methods
# ==============================
Expand Down Expand Up @@ -431,3 +452,12 @@ def build_kv_connector_stats(
which can implement custom aggregation logic on the data dict.
"""
return None

def set_xfer_handshake_metadata(self, metadata: dict[int, dict[int, dict]]) -> None:
    """
    Set the KV connector handshake metadata for this connector.

    This default implementation ignores ``metadata`` and returns None.

    Args:
        metadata (dict): the handshake metadata to set, a two-level
            mapping with int keys at both levels.
            NOTE(review): callers appear to key the outer dict by dp_rank
            and the inner dict by tp_rank, with
            KVConnectorHandshakeMetadata values rather than plain dicts --
            confirm and tighten the annotation if so.
    """
    return None
Loading