Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
5770410
fix: missing NIXL metadata for handshake initialization if instance s…
GuanLuo Oct 7, 2025
8fb478a
test: add test
GuanLuo Oct 10, 2025
f5e82f5
chore: style and fix test
GuanLuo Oct 10, 2025
aa41ea5
chore: fix pre-commit
GuanLuo Oct 10, 2025
d3abc46
fix: fix test
GuanLuo Oct 10, 2025
813dab1
fix: fix NIXL handshake listener cleanup
GuanLuo Oct 10, 2025
009ec48
fix: fix device ID use for NIXL memory registration
GuanLuo Oct 11, 2025
9e2c914
Merge branch 'main' into gluo/nixl
GuanLuo Oct 13, 2025
e52448f
chore: address comment
GuanLuo Oct 14, 2025
102f543
Merge branch 'main' into gluo/nixl
GuanLuo Oct 14, 2025
fa8357f
fix: fix test and CPU case
GuanLuo Oct 20, 2025
a535327
Merge branch 'main' into gluo/nixl
GuanLuo Oct 21, 2025
bb9f577
chore: address comment
GuanLuo Oct 22, 2025
aeb56df
Merge branch 'main' into gluo/nixl
GuanLuo Oct 22, 2025
c49b0bb
style: style
GuanLuo Oct 22, 2025
7cc5a2d
fix: fix test
GuanLuo Oct 22, 2025
84c6130
chore: address comment
GuanLuo Oct 24, 2025
13dce7c
Merge branch 'main' into gluo/nixl
GuanLuo Oct 24, 2025
7009553
fix: use mock nixl lib
GuanLuo Oct 24, 2025
47f08f4
Merge branch 'main' into gluo/nixl
NickLucche Oct 24, 2025
c042ef8
Merge branch 'main' into gluo/nixl
GuanLuo Oct 27, 2025
ea11194
Merge branch 'main' into gluo/nixl
NickLucche Oct 27, 2025
a20e248
Merge branch 'main' into gluo/nixl
GuanLuo Oct 28, 2025
b769ee5
Merge branch 'main' into gluo/nixl
GuanLuo Oct 29, 2025
c728545
remove addition newline
GuanLuo Oct 29, 2025
970d1c4
Merge branch 'main' into gluo/nixl
GuanLuo Oct 30, 2025
ee1d4af
fix: fix pre_commit
GuanLuo Oct 30, 2025
552b0fb
fix: fix up
GuanLuo Oct 30, 2025
da43425
Merge branch 'main' into gluo/nixl
GuanLuo Oct 30, 2025
696dd7b
fix: fix up
GuanLuo Oct 30, 2025
a4e8f47
Merge branch 'main' into gluo/nixl
GuanLuo Oct 31, 2025
99c039a
chore: undo unrelated change
GuanLuo Oct 31, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/features/nixl_connector_usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \
- Default: 5600
- **Required for both prefiller and decoder instances**
- Each vLLM worker needs a unique port on its host; using the same port number across different hosts is fine
- For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank * tp_size + tp_rank (e.g., with `--tensor-parallel-size=4` and base_port=5600, tp_rank 0..3 use ports 5600, 5601, 5602, 5603 on that node).
- For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank (e.g., with `--data-parallel-size=2` and base_port=5600, dp_rank 0..1 use ports 5600 and 5601 on that node).
- Used for the initial NIXL handshake between the prefiller and the decoder

- `VLLM_NIXL_SIDE_CHANNEL_HOST`: Host for side channel communication
Expand Down
106 changes: 102 additions & 4 deletions tests/v1/kv_connector/unit/test_nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
NixlAgentMetadata,
NixlConnector,
NixlConnectorMetadata,
NixlConnectorScheduler,
NixlConnectorWorker,
NixlKVConnectorStats,
)
Expand Down Expand Up @@ -283,6 +284,92 @@ def test_prompt_less_than_block_size():
assert len(scheduler_output.scheduled_new_reqs) == 0


@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper,
)
def test_kv_transfer_handshake(dist_init):
    """Check that a decode worker can retrieve the prefill handshake metadata.

    A scheduler-side connector is given the handshake metadata produced by a
    prefill worker connector. A second worker connector (decode side) then
    runs ``_nixl_handshake`` against it, and the metadata passed to
    ``add_remote_agent`` is compared with what the prefill worker published.
    """

    # Test setup: we create a scheduler that contains a NixlConnector
    # of role SCHEDULER, and expect it to be serving NixlAgentMetadata from
    # all workers of the instance.
    vllm_config = create_vllm_config()
    # in case the test runs on non-GPU machine
    vllm_config.kv_transfer_config.kv_buffer_device = "cpu"
    scheduler = create_scheduler(vllm_config)

    # Create two NixlConnector of role WORKER, one is the worker of
    # the scheduler (prefill), the other is a worker of decode instance.

    # Prefill connector will register KV cache to populate proper handshake
    # metadata.
    prefill_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
    kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
        num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
    )
    shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
    unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
    kv_caches = {
        "layer0": shared_tensor,
        "layer1": unique_tensor,
        "layer2": shared_tensor,
    }
    prefill_connector.register_kv_caches(kv_caches)

    # Simulate EngineCore initialization that would gather connector
    # metadata from all workers. The scheduler connector expects the metadata
    # as dict[int, KVConnectorHandshakeMetadata]; the single worker in this
    # test is registered under key 0.
    # NOTE(review): the key is presumably the worker's tp_rank (the
    # assertions below expect remote_tp_rank == 0) -- confirm against
    # set_xfer_handshake_metadata.
    metadata = {0: prefill_connector.get_handshake_metadata()}
    scheduler_connector = scheduler.get_kv_connector()
    scheduler_connector.set_xfer_handshake_metadata(metadata)

    # Need to shutdown the background thread to release NIXL side channel
    # port. Doing it in `finally` releases the port even when an assertion
    # below fails, so later tests are not affected by a leaked listener.
    try:
        # Simulate a request that finishes prefill, which returns
        # corresponding NixlConnectorMetadata for decode instance.
        BLOCK_SIZE = vllm_config.cache_config.block_size
        NUM_EXTERNAL_FULL_BLOCKS = 2
        NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

        request = create_request(
            request_id=1,
            block_size=BLOCK_SIZE,
            num_tokens=NUM_TOKENS,
            do_remote_decode=True,
        )
        request.status = RequestStatus.FINISHED_LENGTH_CAPPED
        delay, kv_connector_metadata = scheduler.get_kv_connector().request_finished(
            request, [0, 1, 2]
        )
        assert delay

        # Decode connector will be able to create handshake with the
        # prefill connector.
        decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)

        # Here we are testing the retrieval of NIXLAgentMetadata.
        # Knowing the implementation detail, we override the add_remote_agent
        # to validate the metadata received is the same as the one in
        # prefill_connector.
        with patch.object(
            decode_connector.connector_worker, "add_remote_agent"
        ) as mock_add_remote_agent:
            # `return_value` is the attribute a Mock consults when called;
            # assigning any other attribute name has no effect on the call.
            mock_add_remote_agent.return_value = "remote_agent"

            decode_connector.connector_worker._nixl_handshake(
                kv_connector_metadata["remote_host"],
                kv_connector_metadata["remote_port"],
                kv_connector_metadata["tp_size"],
                kv_connector_metadata["remote_engine_id"],
            )

            received_metadata = mock_add_remote_agent.call_args.args
            assert received_metadata[1] == 0  # remote_tp_rank
            assert received_metadata[2] == 1  # remote_tp_size
            assert metadata[0] == received_metadata[0]
    finally:
        scheduler_connector.shutdown()


class FakeNixlConnectorWorker(NixlConnectorWorker):
REMOTE_ENGINE_ID = "remote_engine"

Expand Down Expand Up @@ -313,6 +400,7 @@ def _nixl_handshake(
engine_id=self.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
device_id=0,
num_blocks=1,
block_lens=self.block_len_per_layer,
attn_backend_name=self.backend_name,
Expand Down Expand Up @@ -559,6 +647,7 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init):
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
device_id=0,
num_blocks=1,
block_lens=worker.block_len_per_layer,
attn_backend_name=worker.backend_name,
Expand Down Expand Up @@ -611,6 +700,7 @@ def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental(
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
device_id=0,
num_blocks=1,
# prefill TP=1, decode TP=2, remote block_lens is double to local
block_lens=[i * 2 for i in worker.block_len_per_layer],
Expand Down Expand Up @@ -1005,6 +1095,8 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
_ = llm.generate([f"What is the capital of France? {padding}"], sampling_params)
# Request-0 times out and is cleared!
assert "0" not in req_to_blocks
# Need to shutdown the background thread to release NIXL side channel port
llm.llm_engine.engine_core.shutdown()


def test_register_kv_caches(dist_init):
Expand Down Expand Up @@ -1177,13 +1269,15 @@ def test_shutdown_cleans_up_resources(dist_init):
"""Test that shutdown() properly cleans up all resources."""
vllm_config = create_vllm_config()

scheduler = NixlConnectorScheduler(
vllm_config, vllm_config.kv_transfer_config.engine_id
)
worker = NixlConnectorWorker(vllm_config, vllm_config.kv_transfer_config.engine_id)
nixl_wrapper = worker.nixl_wrapper

with (
patch.object(worker, "_handshake_initiation_executor") as mock_exec,
patch.object(worker, "_nixl_handshake_listener_t") as mock_listener,
patch.object(worker, "_nixl_handshake_listener_stop_event") as mock_event,
patch.object(scheduler, "_nixl_handshake_listener_t") as mock_listener,
patch.object(nixl_wrapper, "release_xfer_handle") as mock_rel_xfer,
patch.object(nixl_wrapper, "release_dlist_handle") as mock_rel_dlist,
patch.object(nixl_wrapper, "remove_remote_agent") as mock_rem_agent,
Expand All @@ -1204,8 +1298,12 @@ def test_shutdown_cleans_up_resources(dist_init):
worker.shutdown()

mock_exec.shutdown.assert_called_with(wait=False)
mock_event.set.assert_called_once()
mock_listener.join.assert_called_once_with(timeout=1.0)

# Same sequence on scheduler.shutdown()
scheduler.shutdown()
scheduler.shutdown()
scheduler.shutdown()
mock_listener.join.assert_called_once()

mock_rel_xfer.assert_called_once_with(123)
assert mock_rel_dlist.call_count == 2
Expand Down
32 changes: 32 additions & 0 deletions vllm/distributed/kv_transfer/kv_connector/v1/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,15 @@ class KVConnectorRole(enum.Enum):
WORKER = 1


class KVConnectorHandshakeMetadata(ABC):  # noqa: B024
    """
    Metadata used for out of band connector handshake between
    P/D workers. This needs to be serializable.
    """


class KVConnectorMetadata(ABC): # noqa: B024
"""
Abstract Metadata used to communicate between the
Expand Down Expand Up @@ -320,6 +329,18 @@ def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]:
"""
return None

def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
    """
    Return this connector's handshake metadata, if it has any.

    The metadata is meant for the out-of-band connector handshake
    between P/D workers. This default implementation has nothing to
    offer and always answers with None.

    Returns:
        KVConnectorHandshakeMetadata: the handshake metadata.
        None if no handshake metadata is available.
    """
    return None

# ==============================
# Scheduler-side methods
# ==============================
Expand Down Expand Up @@ -477,6 +498,17 @@ def build_kv_connector_stats(
"""
return None

def set_xfer_handshake_metadata(
    self, metadata: dict[int, KVConnectorHandshakeMetadata]
) -> None:
    """
    Set the KV connector handshake metadata for this connector.

    This default implementation ignores its argument and returns None.

    Args:
        metadata (dict[int, KVConnectorHandshakeMetadata]): the handshake
            metadata to set, as a mapping from an integer key to one
            handshake metadata object.
            NOTE(review): the key is presumably a worker rank; confirm
            against the callers that build this mapping.
    """
    return None

@classmethod
def build_prom_metrics(
cls,
Expand Down
Loading