Skip to content
Open
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/features/nixl_connector_usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \
- Default: 5600
- **Required for both prefiller and decoder instances**
- Each vLLM worker needs a unique port on its host; using the same port number across different hosts is fine
- For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank * tp_size + tp_rank (e.g., with `--tensor-parallel-size=4` and base_port=5600, tp_rank 0..3 use ports 5600, 5601, 5602, 5603 on that node).
- For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank (e.g., with `--data-parallel-size=2` and base_port=5600, dp_rank 0..1 use ports 5600 and 5601 on that node).
- Used for the initial NIXL handshake between the prefiller and the decoder

- `VLLM_NIXL_SIDE_CHANNEL_HOST`: Host for side channel communication
Expand Down
106 changes: 102 additions & 4 deletions tests/v1/kv_connector/unit/test_nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
NixlAgentMetadata,
NixlConnector,
NixlConnectorMetadata,
NixlConnectorScheduler,
NixlConnectorWorker,
NixlKVConnectorStats,
)
Expand Down Expand Up @@ -283,6 +284,92 @@ def test_prompt_less_than_block_size():
assert len(scheduler_output.scheduled_new_reqs) == 0


@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper,
)
def test_kv_transfer_handshake(dist_init):
    """Check that a decode worker receives the prefill worker's handshake
    metadata when it performs a NIXL handshake against the scheduler-side
    connector.
    """

    # Test setup: we create a scheduler that contains a NixlConnector
    # of role SCHEDULER, and expect it to be serving NixlAgentMetadata from
    # all workers of the instance.
    vllm_config = create_vllm_config()
    # In case the test runs on a non-GPU machine.
    vllm_config.kv_transfer_config.kv_buffer_device = "cpu"
    scheduler = create_scheduler(vllm_config)

    # Create two NixlConnectors of role WORKER: one is the worker of
    # the scheduler (prefill), the other is a worker of the decode instance.

    # The prefill connector registers KV caches so that its handshake
    # metadata is populated.
    prefill_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
    kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
        num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
    )
    shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
    unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
    kv_caches = {
        "layer0": shared_tensor,
        "layer1": unique_tensor,
        "layer2": shared_tensor,
    }
    prefill_connector.register_kv_caches(kv_caches)

    # Simulate the EngineCore initialization that would gather connector
    # metadata from all workers. The scheduler connector expects a
    # dict[int, KVConnectorHandshakeMetadata]; this test has a single
    # worker, stored under key 0.
    # NOTE(review): the previous comment described two key levels (dp_rank,
    # then tp_rank) although the mapping has one; confirm which rank the
    # key denotes.
    metadata = {0: prefill_connector.get_handshake_metadata()}
    scheduler_connector = scheduler.get_kv_connector()
    scheduler_connector.set_xfer_handshake_metadata(metadata)

    try:
        # Simulate a request that finishes prefill, which returns the
        # corresponding connector metadata for the decode instance.
        BLOCK_SIZE = vllm_config.cache_config.block_size
        NUM_EXTERNAL_FULL_BLOCKS = 2
        NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

        request = create_request(
            request_id=1,
            block_size=BLOCK_SIZE,
            num_tokens=NUM_TOKENS,
            do_remote_decode=True,
        )
        request.status = RequestStatus.FINISHED_LENGTH_CAPPED
        delay, kv_connector_metadata = scheduler_connector.request_finished(
            request, [0, 1, 2]
        )
        assert delay

        # The decode connector should be able to handshake with the
        # prefill connector.
        decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)

        # Here we are testing the retrieval of NixlAgentMetadata.
        # Knowing the implementation detail, we override add_remote_agent
        # to validate that the metadata received is the same as the one
        # produced by prefill_connector.
        with patch.object(
            decode_connector.connector_worker, "add_remote_agent"
        ) as mock_add_remote_agent:
            # `return_value` (not `return_type`) is the attribute Mock uses
            # to configure what the stubbed call returns.
            mock_add_remote_agent.return_value = "remote_agent"

            decode_connector.connector_worker._nixl_handshake(
                kv_connector_metadata["remote_host"],
                kv_connector_metadata["remote_port"],
                kv_connector_metadata["tp_size"],
                kv_connector_metadata["remote_engine_id"],
            )

            # Fail with a clear message if the handshake never reached
            # add_remote_agent (call_args would otherwise be None).
            mock_add_remote_agent.assert_called()
            received_metadata = mock_add_remote_agent.call_args.args
            assert received_metadata[1] == 0  # remote_tp_rank
            assert received_metadata[2] == 1  # remote_tp_size
            assert metadata[0] == received_metadata[0]
    finally:
        # Shut down the background thread to release the NIXL side channel
        # port, even when an assertion above fails.
        scheduler_connector.shutdown()


class FakeNixlConnectorWorker(NixlConnectorWorker):
REMOTE_ENGINE_ID = "remote_engine"

Expand Down Expand Up @@ -313,6 +400,7 @@ def _nixl_handshake(
engine_id=self.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
device_id=0,
num_blocks=1,
block_lens=self.block_len_per_layer,
attn_backend_name=self.backend_name,
Expand Down Expand Up @@ -559,6 +647,7 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init):
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
device_id=0,
num_blocks=1,
block_lens=worker.block_len_per_layer,
attn_backend_name=worker.backend_name,
Expand Down Expand Up @@ -611,6 +700,7 @@ def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental(
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
device_id=0,
num_blocks=1,
# prefill TP=1, decode TP=2, remote block_lens is double to local
block_lens=[i * 2 for i in worker.block_len_per_layer],
Expand Down Expand Up @@ -1004,6 +1094,8 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
_ = llm.generate([f"What is the capital of France? {padding}"], sampling_params)
# Request-0 times out and is cleared!
assert "0" not in req_to_blocks
# Need to shutdown the background thread to release NIXL side channel port
llm.llm_engine.engine_core.shutdown()


def test_register_kv_caches(dist_init):
Expand Down Expand Up @@ -1176,13 +1268,15 @@ def test_shutdown_cleans_up_resources(dist_init):
"""Test that shutdown() properly cleans up all resources."""
vllm_config = create_vllm_config()

scheduler = NixlConnectorScheduler(
vllm_config, vllm_config.kv_transfer_config.engine_id
)
worker = NixlConnectorWorker(vllm_config, vllm_config.kv_transfer_config.engine_id)
nixl_wrapper = worker.nixl_wrapper

with (
patch.object(worker, "_handshake_initiation_executor") as mock_exec,
patch.object(worker, "_nixl_handshake_listener_t") as mock_listener,
patch.object(worker, "_nixl_handshake_listener_stop_event") as mock_event,
patch.object(scheduler, "_nixl_handshake_listener_t") as mock_listener,
patch.object(nixl_wrapper, "release_xfer_handle") as mock_rel_xfer,
patch.object(nixl_wrapper, "release_dlist_handle") as mock_rel_dlist,
patch.object(nixl_wrapper, "remove_remote_agent") as mock_rem_agent,
Expand All @@ -1203,8 +1297,12 @@ def test_shutdown_cleans_up_resources(dist_init):
worker.shutdown()

mock_exec.shutdown.assert_called_with(wait=False)
mock_event.set.assert_called_once()
mock_listener.join.assert_called_once_with(timeout=1.0)

# Same sequence on scheduler.shutdown()
scheduler.shutdown()
scheduler.shutdown()
scheduler.shutdown()
mock_listener.join.assert_called_once()

mock_rel_xfer.assert_called_once_with(123)
assert mock_rel_dlist.call_count == 2
Expand Down
32 changes: 32 additions & 0 deletions vllm/distributed/kv_transfer/kv_connector/v1/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,15 @@ class KVConnectorRole(enum.Enum):
WORKER = 1


class KVConnectorHandshakeMetadata(ABC):  # noqa: B024
    """
    Metadata used for the out-of-band connector handshake between
    P/D workers. This needs to be serializable.
    """


class KVConnectorMetadata(ABC): # noqa: B024
"""
Abstract Metadata used to communicate between the
Expand Down Expand Up @@ -276,6 +285,18 @@ def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]:
"""
return None

def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
    """
    Return the handshake metadata of this KVConnector, if it has any.

    The metadata is consumed by the out-of-band connector handshake
    between P/D workers. This default implementation provides none.

    Returns:
        KVConnectorHandshakeMetadata | None: the handshake metadata,
        or None when no handshake metadata is available.
    """
    return None

# ==============================
# Scheduler-side methods
# ==============================
Expand Down Expand Up @@ -432,3 +453,14 @@ def build_kv_connector_stats(
which can implement custom aggregation logic on the data dict.
"""
return None

def set_xfer_handshake_metadata(
    self, metadata: dict[int, KVConnectorHandshakeMetadata]
) -> None:
    """
    Set the KV connector handshake metadata for this connector.

    This default implementation ignores the argument and does nothing.

    Args:
        metadata (dict[int, KVConnectorHandshakeMetadata]): the handshake
            metadata to set, keyed by an integer (presumably a worker
            rank -- TODO confirm which rank against the caller).
    """
    return None
Loading