Skip to content

Commit d6517be

Browse files
GuanLuo and NickLucche authored
[Bugfix] Missing NIXL metadata for handshake initialization if instance spans multi-node (#26338)
Signed-off-by: Guan Luo <[email protected]> Signed-off-by: GuanLuo <[email protected]> Signed-off-by: Guan Luo <[email protected]> Co-authored-by: Nicolò Lucchesi <[email protected]>
1 parent 7e06c40 commit d6517be

File tree

7 files changed

+321
-95
lines changed

7 files changed

+321
-95
lines changed

docs/features/nixl_connector_usage.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \
8181
- Default: 5600
8282
- **Required for both prefiller and decoder instances**
8383
- Each vLLM worker needs a unique port on its host; using the same port number across different hosts is fine
84-
- For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank * tp_size + tp_rank (e.g., with `--tensor-parallel-size=4` and base_port=5600, tp_rank 0..3 use ports 5600, 5601, 5602, 5603 on that node).
84+
- For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank (e.g., with `--data-parallel-size=2` and base_port=5600, dp_rank 0..1 use ports 5600 and 5601 on that node).
8585
- Used for the initial NIXL handshake between the prefiller and the decoder
8686

8787
- `VLLM_NIXL_SIDE_CHANNEL_HOST`: Host for side channel communication

tests/v1/kv_connector/unit/test_nixl_connector.py

Lines changed: 102 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
NixlAgentMetadata,
2828
NixlConnector,
2929
NixlConnectorMetadata,
30+
NixlConnectorScheduler,
3031
NixlConnectorWorker,
3132
NixlKVConnectorStats,
3233
)
@@ -283,6 +284,92 @@ def test_prompt_less_than_block_size():
283284
assert len(scheduler_output.scheduled_new_reqs) == 0
284285

285286

287+
@patch(
288+
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
289+
FakeNixlWrapper,
290+
)
291+
def test_kv_transfer_handshake(dist_init):
292+
"""Unit test for basic NixlConnector interface functionality."""
293+
294+
# Test setup: we create a scheduler that contains a NixlConnector
295+
# of role SCHEDULER, and expect it to be serving NixlAgentMetadata from
296+
# all workers of the instance.
297+
vllm_config = create_vllm_config()
298+
# in case the test runs on non-GPU machine
299+
vllm_config.kv_transfer_config.kv_buffer_device = "cpu"
300+
scheduler = create_scheduler(vllm_config)
301+
302+
# Create two NixlConnector of role WORKER, one is the worker of
303+
# the scheduler (prefill), the other is a worker of decode instance.
304+
305+
# Prefill connector will register KV cache to populate proper handshake
306+
# metadata.
307+
prefill_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
308+
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
309+
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
310+
)
311+
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
312+
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
313+
kv_caches = {
314+
"layer0": shared_tensor,
315+
"layer1": unique_tensor,
316+
"layer2": shared_tensor,
317+
}
318+
prefill_connector.register_kv_caches(kv_caches)
319+
320+
# Simulate EngineCore initialization that would
321+
# gather connector metadata from all workers, the scheduler connector
322+
# expects metadata to be in dict[int, KVConnectorHandshakeMetadata],
323+
# where the key is the tp_rank of the worker that produced the metadata.
324+
metadata = {0: prefill_connector.get_handshake_metadata()}
325+
scheduler_connector = scheduler.get_kv_connector()
326+
scheduler_connector.set_xfer_handshake_metadata(metadata)
327+
328+
# Simulate a request that finishes prefill, which returns
329+
# corresponding NixlConnectorMetadata for decode instance.
330+
BLOCK_SIZE = vllm_config.cache_config.block_size
331+
NUM_EXTERNAL_FULL_BLOCKS = 2
332+
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
333+
334+
request = create_request(
335+
request_id=1,
336+
block_size=BLOCK_SIZE,
337+
num_tokens=NUM_TOKENS,
338+
do_remote_decode=True,
339+
)
340+
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
341+
delay, kv_connector_metadata = scheduler.get_kv_connector().request_finished(
342+
request, [0, 1, 2]
343+
)
344+
assert delay
345+
346+
# Decode connector will be able to create handshake with the prefill connector.
347+
decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
348+
349+
# Here we are testing the retrieval of NixlAgentMetadata.
350+
# Knowing the implementation detail, we override the add_remote_agent
351+
# to validate the metadata received is the same as the one in prefill_connector.
352+
with patch.object(
353+
decode_connector.connector_worker, "add_remote_agent"
354+
) as mock_add_remote_agent:
355+
mock_add_remote_agent.return_type = "remote_agent"
356+
357+
decode_connector.connector_worker._nixl_handshake(
358+
kv_connector_metadata["remote_host"],
359+
kv_connector_metadata["remote_port"],
360+
kv_connector_metadata["tp_size"],
361+
kv_connector_metadata["remote_engine_id"],
362+
)
363+
364+
received_metadata = mock_add_remote_agent.call_args.args
365+
assert received_metadata[1] == 0 # remote_tp_rank
366+
assert received_metadata[2] == 1 # remote_tp_size
367+
assert metadata[0] == received_metadata[0]
368+
369+
# Need to shutdown the background thread to release NIXL side channel port
370+
scheduler_connector.shutdown()
371+
372+
286373
class FakeNixlConnectorWorker(NixlConnectorWorker):
287374
REMOTE_ENGINE_ID = "remote_engine"
288375

@@ -313,6 +400,7 @@ def _nixl_handshake(
313400
engine_id=self.REMOTE_ENGINE_ID,
314401
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
315402
kv_caches_base_addr=[0],
403+
device_id=0,
316404
num_blocks=1,
317405
block_lens=self.block_len_per_layer,
318406
attn_backend_name=self.backend_name,
@@ -559,6 +647,7 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init):
559647
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
560648
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
561649
kv_caches_base_addr=[0],
650+
device_id=0,
562651
num_blocks=1,
563652
block_lens=worker.block_len_per_layer,
564653
attn_backend_name=worker.backend_name,
@@ -611,6 +700,7 @@ def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental(
611700
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
612701
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
613702
kv_caches_base_addr=[0],
703+
device_id=0,
614704
num_blocks=1,
615705
# prefill TP=1, decode TP=2, remote block_lens is double to local
616706
block_lens=[i * 2 for i in worker.block_len_per_layer],
@@ -1005,6 +1095,8 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
10051095
_ = llm.generate([f"What is the capital of France? {padding}"], sampling_params)
10061096
# Request-0 times out and is cleared!
10071097
assert "0" not in req_to_blocks
1098+
# Need to shutdown the background thread to release NIXL side channel port
1099+
llm.llm_engine.engine_core.shutdown()
10081100

10091101

10101102
def test_register_kv_caches(dist_init):
@@ -1177,13 +1269,15 @@ def test_shutdown_cleans_up_resources(dist_init):
11771269
"""Test that shutdown() properly cleans up all resources."""
11781270
vllm_config = create_vllm_config()
11791271

1272+
scheduler = NixlConnectorScheduler(
1273+
vllm_config, vllm_config.kv_transfer_config.engine_id
1274+
)
11801275
worker = NixlConnectorWorker(vllm_config, vllm_config.kv_transfer_config.engine_id)
11811276
nixl_wrapper = worker.nixl_wrapper
11821277

11831278
with (
11841279
patch.object(worker, "_handshake_initiation_executor") as mock_exec,
1185-
patch.object(worker, "_nixl_handshake_listener_t") as mock_listener,
1186-
patch.object(worker, "_nixl_handshake_listener_stop_event") as mock_event,
1280+
patch.object(scheduler, "_nixl_handshake_listener_t") as mock_listener,
11871281
patch.object(nixl_wrapper, "release_xfer_handle") as mock_rel_xfer,
11881282
patch.object(nixl_wrapper, "release_dlist_handle") as mock_rel_dlist,
11891283
patch.object(nixl_wrapper, "remove_remote_agent") as mock_rem_agent,
@@ -1204,8 +1298,12 @@ def test_shutdown_cleans_up_resources(dist_init):
12041298
worker.shutdown()
12051299

12061300
mock_exec.shutdown.assert_called_with(wait=False)
1207-
mock_event.set.assert_called_once()
1208-
mock_listener.join.assert_called_once_with(timeout=1.0)
1301+
1302+
# Same sequence on scheduler.shutdown()
1303+
scheduler.shutdown()
1304+
scheduler.shutdown()
1305+
scheduler.shutdown()
1306+
mock_listener.join.assert_called_once()
12091307

12101308
mock_rel_xfer.assert_called_once_with(123)
12111309
assert mock_rel_dlist.call_count == 2

vllm/distributed/kv_transfer/kv_connector/v1/base.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,15 @@ class KVConnectorRole(enum.Enum):
122122
WORKER = 1
123123

124124

125+
class KVConnectorHandshakeMetadata(ABC): # noqa: B024
126+
"""
127+
Metadata used for out of band connector handshake between
128+
P/D workers. This needs to be serializable.
129+
"""
130+
131+
pass
132+
133+
125134
class KVConnectorMetadata(ABC): # noqa: B024
126135
"""
127136
Abstract Metadata used to communicate between the
@@ -320,6 +329,18 @@ def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]:
320329
"""
321330
return None
322331

332+
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
333+
"""
334+
Get the KVConnector handshake metadata for this connector.
335+
This metadata is used for out-of-band connector handshake
336+
between P/D workers.
337+
338+
Returns:
339+
KVConnectorHandshakeMetadata: the handshake metadata.
340+
None if no handshake metadata is available.
341+
"""
342+
return None
343+
323344
# ==============================
324345
# Scheduler-side methods
325346
# ==============================
@@ -477,6 +498,17 @@ def build_kv_connector_stats(
477498
"""
478499
return None
479500

501+
def set_xfer_handshake_metadata(
502+
self, metadata: dict[int, KVConnectorHandshakeMetadata]
503+
) -> None:
504+
"""
505+
Set the KV connector handshake metadata for this connector.
506+
507+
Args:
508+
metadata (dict[int, KVConnectorHandshakeMetadata]): the handshake metadata to set.
509+
"""
510+
return None
511+
480512
@classmethod
481513
def build_prom_metrics(
482514
cls,

0 commit comments

Comments
 (0)