Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
5770410
fix: missing NIXL metadata for handshake initialization if instance s…
GuanLuo Oct 7, 2025
8fb478a
test: add test
GuanLuo Oct 10, 2025
f5e82f5
chore: style and fix test
GuanLuo Oct 10, 2025
aa41ea5
chore: fix pre-commit
GuanLuo Oct 10, 2025
d3abc46
fix: fix test
GuanLuo Oct 10, 2025
813dab1
fix: fix NIXL handshake listener cleanup
GuanLuo Oct 10, 2025
009ec48
fix: fix device ID use for NIXL memory registration
GuanLuo Oct 11, 2025
9e2c914
Merge branch 'main' into gluo/nixl
GuanLuo Oct 13, 2025
e52448f
chore: address comment
GuanLuo Oct 14, 2025
102f543
Merge branch 'main' into gluo/nixl
GuanLuo Oct 14, 2025
fa8357f
fix: fix test and CPU case
GuanLuo Oct 20, 2025
a535327
Merge branch 'main' into gluo/nixl
GuanLuo Oct 21, 2025
bb9f577
chore: address comment
GuanLuo Oct 22, 2025
aeb56df
Merge branch 'main' into gluo/nixl
GuanLuo Oct 22, 2025
c49b0bb
style: style
GuanLuo Oct 22, 2025
7cc5a2d
fix: fix test
GuanLuo Oct 22, 2025
84c6130
chore: address comment
GuanLuo Oct 24, 2025
13dce7c
Merge branch 'main' into gluo/nixl
GuanLuo Oct 24, 2025
7009553
fix: use mock nixl lib
GuanLuo Oct 24, 2025
47f08f4
Merge branch 'main' into gluo/nixl
NickLucche Oct 24, 2025
c042ef8
Merge branch 'main' into gluo/nixl
GuanLuo Oct 27, 2025
ea11194
Merge branch 'main' into gluo/nixl
NickLucche Oct 27, 2025
a20e248
Merge branch 'main' into gluo/nixl
GuanLuo Oct 28, 2025
b769ee5
Merge branch 'main' into gluo/nixl
GuanLuo Oct 29, 2025
c728545
remove addition newline
GuanLuo Oct 29, 2025
970d1c4
Merge branch 'main' into gluo/nixl
GuanLuo Oct 30, 2025
ee1d4af
fix: fix pre_commit
GuanLuo Oct 30, 2025
552b0fb
fix: fix up
GuanLuo Oct 30, 2025
da43425
Merge branch 'main' into gluo/nixl
GuanLuo Oct 30, 2025
696dd7b
fix: fix up
GuanLuo Oct 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 100 additions & 2 deletions tests/v1/kv_connector/unit/test_nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
NixlAgentMetadata,
NixlConnector,
NixlConnectorMetadata,
NixlConnectorScheduler,
NixlConnectorWorker,
NixlKVConnectorStats,
)
Expand Down Expand Up @@ -284,6 +285,93 @@ def test_prompt_less_than_block_size():
assert len(scheduler_output.scheduled_new_reqs) == 0


def test_kv_transfer_metadata(dist_init):
"""Unit test for basic NixlConnector interface functionality."""

# Test setup, we creates a scheduler that contains a NixlConnector
# of role SCHEDULER, and expect it to be serving NixlAgentMetadata from
# all workers of the instance.
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)

# Create two NixlConnector of role WORKER, one is the worker of
# the scheduler (prefill), the other is a worker of decode instance.

# Prefill connector will register KV cache to populate proper handshake
# metadata.
prefill_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
"layer2": shared_tensor,
}
prefill_connector.register_kv_caches(kv_caches)

# Simulate EngineCore initialization that would
# gather connector metadata from all workers, the scheduler connector
# expects metadata to be in Dict[int, Dict[int, KVConnectorHandshakeMetadata]],
# where the first key is the dp_rank, the second key is the tp_rank.
metadata = {0: {0: prefill_connector.get_handshake_metadata()}}
scheduler_connector = scheduler.get_kv_connector()
scheduler_connector.set_xfer_handshake_metadata(metadata)

# Simulate a request that finishes prefill, which returns
# corresponding NixlConnectorMetadata for decode instance.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

request = create_request(
request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_decode=True,
)
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
delay, kv_connector_metadata = scheduler.get_kv_connector().request_finished(
request, [0, 1, 2]
)
assert delay

# Decode connector will be able to create handshake with the prefill connector.
decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)

# Here we are testing the retrieval of NIXLAgentMetadata.
# Knowing the implementation detail, we override the add_remote_agent
# to validate the metadata received is the same as the one in prefill_connector.
received_metadata = None

def mock_add_remote_agent(
agent_metadata: NixlAgentMetadata,
remote_tp_rank: int,
remote_tp_size: int,
):
nonlocal received_metadata
received_metadata = (agent_metadata, remote_tp_rank, remote_tp_size)
return "remote_agent"

decode_connector.connector_worker.add_remote_agent = mock_add_remote_agent

decode_connector.connector_worker._nixl_handshake(
kv_connector_metadata["remote_host"],
kv_connector_metadata["remote_port"],
kv_connector_metadata["tp_size"],
kv_connector_metadata["remote_engine_id"],
)
assert received_metadata is not None
assert received_metadata[1] == 0 # remote_tp_rank
assert received_metadata[2] == 1 # remote_tp_size
assert metadata[0][0] == received_metadata[0]

# Need to shutdown the background thread to release NIXL side channel port
scheduler_connector.shutdown()


class FakeNixlConnectorWorker(NixlConnectorWorker):
REMOTE_ENGINE_ID = "remote_engine"

Expand Down Expand Up @@ -874,6 +962,8 @@ def _run_abort_timeout_test(llm_kwargs: dict, timeout: int):
_ = llm.generate([f"What is the capital of France? {padding}"], sampling_params)
# Request-0 times out and is cleared!
assert "0" not in req_to_blocks
# Need to shutdown the background thread to release NIXL side channel port
llm.llm_engine.engine_core.shutdown()


def test_register_kv_caches(dist_init):
Expand Down Expand Up @@ -1043,12 +1133,15 @@ def test_shutdown_cleans_up_resources(dist_init):
"""Test that shutdown() properly cleans up all resources."""
vllm_config = create_vllm_config()

scheduler = NixlConnectorScheduler(
vllm_config, vllm_config.kv_transfer_config.engine_id
)
worker = NixlConnectorWorker(vllm_config, vllm_config.kv_transfer_config.engine_id)
nixl_wrapper = worker.nixl_wrapper

with (
patch.object(worker, "_handshake_initiation_executor") as mock_exec,
patch.object(worker, "_nixl_handshake_listener_t") as mock_listener,
patch.object(scheduler, "_nixl_handshake_listener_t") as mock_listener,
patch.object(nixl_wrapper, "release_xfer_handle") as mock_rel_xfer,
patch.object(nixl_wrapper, "release_dlist_handle") as mock_rel_dlist,
patch.object(nixl_wrapper, "remove_remote_agent") as mock_rem_agent,
Expand All @@ -1067,7 +1160,12 @@ def test_shutdown_cleans_up_resources(dist_init):
worker.shutdown()

mock_exec.shutdown.assert_called_with(wait=False)
mock_listener.join.assert_called_once_with(timeout=0)

# Same sequence on scheduler.shutdown()
scheduler.shutdown()
scheduler.shutdown()
scheduler.shutdown()
mock_listener.join.assert_called_once()

mock_rel_xfer.assert_called_once_with(123)
assert mock_rel_dlist.call_count == 2
Expand Down
30 changes: 30 additions & 0 deletions vllm/distributed/kv_transfer/kv_connector/v1/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,15 @@ class KVConnectorRole(enum.Enum):
WORKER = 1


class KVConnectorHandshakeMetadata(ABC): # noqa: B024
"""
Metadata used for out of band connector handshakeandshake between
P/D workers. This needs to serializeable.
"""

pass


class KVConnectorMetadata(ABC): # noqa: B024
"""
Abstract Metadata used to communicate between the
Expand Down Expand Up @@ -272,6 +281,18 @@ def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]:
"""
return None

def get_handshake_metadata(self) -> Optional[KVConnectorHandshakeMetadata]:
"""
Get the KVConnector handshake metadata for this connector.
This metadata is used for out-of-band connector handshake
between P/D workers.

Returns:
KVConnectorHandshakeMetadata: the handshake metadata.
None if no handshake metadata is available.
"""
return None

# ==============================
# Scheduler-side methods
# ==============================
Expand Down Expand Up @@ -427,3 +448,12 @@ def build_kv_connector_stats(
which can implement custom aggregation logic on the data dict.
"""
return None

def set_xfer_handshake_metadata(self, metadata: dict[int, dict[int, dict]]) -> None:
"""
Set the KV connector handshake metadata for this connector.

Args:
metadata (dict): the handshake metadata to set.
"""
return None
Loading