2727 NixlAgentMetadata ,
2828 NixlConnector ,
2929 NixlConnectorMetadata ,
30+ NixlConnectorScheduler ,
3031 NixlConnectorWorker ,
3132 NixlKVConnectorStats ,
3233)
@@ -283,6 +284,92 @@ def test_prompt_less_than_block_size():
283284 assert len (scheduler_output .scheduled_new_reqs ) == 0
284285
285286
287+ @patch (
288+ "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper" ,
289+ FakeNixlWrapper ,
290+ )
291+ def test_kv_transfer_handshake (dist_init ):
292+ """Unit test for basic NixlConnector interface functionality."""
293+
294+ # Test setup, we creates a scheduler that contains a NixlConnector
295+ # of role SCHEDULER, and expect it to be serving NixlAgentMetadata from
296+ # all workers of the instance.
297+ vllm_config = create_vllm_config ()
298+ # in case the test runs on non-GPU machine
299+ vllm_config .kv_transfer_config .kv_buffer_device = "cpu"
300+ scheduler = create_scheduler (vllm_config )
301+
302+ # Create two NixlConnector of role WORKER, one is the worker of
303+ # the scheduler (prefill), the other is a worker of decode instance.
304+
305+ # Prefill connector will register KV cache to populate proper handshake
306+ # metadata.
307+ prefill_connector = NixlConnector (vllm_config , KVConnectorRole .WORKER )
308+ kv_cache_shape = FlashAttentionBackend .get_kv_cache_shape (
309+ num_blocks = 2 , block_size = 16 , num_kv_heads = 4 , head_size = 64
310+ )
311+ shared_tensor = torch .zeros (* kv_cache_shape , dtype = torch .float16 )
312+ unique_tensor = torch .zeros (* kv_cache_shape , dtype = torch .float16 )
313+ kv_caches = {
314+ "layer0" : shared_tensor ,
315+ "layer1" : unique_tensor ,
316+ "layer2" : shared_tensor ,
317+ }
318+ prefill_connector .register_kv_caches (kv_caches )
319+
320+ # Simulate EngineCore initialization that would
321+ # gather connector metadata from all workers, the scheduler connector
322+ # expects metadata to be in dict[int, KVConnectorHandshakeMetadata],
323+ # where the first key is the dp_rank, the second key is the tp_rank.
324+ metadata = {0 : prefill_connector .get_handshake_metadata ()}
325+ scheduler_connector = scheduler .get_kv_connector ()
326+ scheduler_connector .set_xfer_handshake_metadata (metadata )
327+
328+ # Simulate a request that finishes prefill, which returns
329+ # corresponding NixlConnectorMetadata for decode instance.
330+ BLOCK_SIZE = vllm_config .cache_config .block_size
331+ NUM_EXTERNAL_FULL_BLOCKS = 2
332+ NUM_TOKENS = int (BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5 ))
333+
334+ request = create_request (
335+ request_id = 1 ,
336+ block_size = BLOCK_SIZE ,
337+ num_tokens = NUM_TOKENS ,
338+ do_remote_decode = True ,
339+ )
340+ request .status = RequestStatus .FINISHED_LENGTH_CAPPED
341+ delay , kv_connector_metadata = scheduler .get_kv_connector ().request_finished (
342+ request , [0 , 1 , 2 ]
343+ )
344+ assert delay
345+
346+ # Decode connector will be able to create handshake with the prefill connector.
347+ decode_connector = NixlConnector (vllm_config , KVConnectorRole .WORKER )
348+
349+ # Here we are testing the retrieval of NIXLAgentMetadata.
350+ # Knowing the implementation detail, we override the add_remote_agent
351+ # to validate the metadata received is the same as the one in prefill_connector.
352+ with patch .object (
353+ decode_connector .connector_worker , "add_remote_agent"
354+ ) as mock_add_remote_agent :
355+ mock_add_remote_agent .return_type = "remote_agent"
356+
357+ decode_connector .connector_worker ._nixl_handshake (
358+ kv_connector_metadata ["remote_host" ],
359+ kv_connector_metadata ["remote_port" ],
360+ kv_connector_metadata ["tp_size" ],
361+ kv_connector_metadata ["remote_engine_id" ],
362+ )
363+
364+ received_metadata = mock_add_remote_agent .call_args .args
365+ assert received_metadata [1 ] == 0 # remote_tp_rank
366+ assert received_metadata [2 ] == 1 # remote_tp_size
367+ assert metadata [0 ] == received_metadata [0 ]
368+
369+ # Need to shutdown the background thread to release NIXL side channel port
370+ scheduler_connector .shutdown ()
371+
372+
286373class FakeNixlConnectorWorker (NixlConnectorWorker ):
287374 REMOTE_ENGINE_ID = "remote_engine"
288375
@@ -313,6 +400,7 @@ def _nixl_handshake(
313400 engine_id = self .REMOTE_ENGINE_ID ,
314401 agent_metadata = FakeNixlWrapper .AGENT_METADATA ,
315402 kv_caches_base_addr = [0 ],
403+ device_id = 0 ,
316404 num_blocks = 1 ,
317405 block_lens = self .block_len_per_layer ,
318406 attn_backend_name = self .backend_name ,
@@ -559,6 +647,7 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init):
559647 engine_id = FakeNixlConnectorWorker .REMOTE_ENGINE_ID ,
560648 agent_metadata = FakeNixlWrapper .AGENT_METADATA ,
561649 kv_caches_base_addr = [0 ],
650+ device_id = 0 ,
562651 num_blocks = 1 ,
563652 block_lens = worker .block_len_per_layer ,
564653 attn_backend_name = worker .backend_name ,
@@ -611,6 +700,7 @@ def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental(
611700 engine_id = FakeNixlConnectorWorker .REMOTE_ENGINE_ID ,
612701 agent_metadata = FakeNixlWrapper .AGENT_METADATA ,
613702 kv_caches_base_addr = [0 ],
703+ device_id = 0 ,
614704 num_blocks = 1 ,
615705 # prefill TP=1, decode TP=2, remote block_lens is double to local
616706 block_lens = [i * 2 for i in worker .block_len_per_layer ],
@@ -1005,6 +1095,8 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
10051095 _ = llm .generate ([f"What is the capital of France? { padding } " ], sampling_params )
10061096 # Request-0 times out and is cleared!
10071097 assert "0" not in req_to_blocks
1098+ # Need to shutdown the background thread to release NIXL side channel port
1099+ llm .llm_engine .engine_core .shutdown ()
10081100
10091101
10101102def test_register_kv_caches (dist_init ):
@@ -1177,13 +1269,15 @@ def test_shutdown_cleans_up_resources(dist_init):
11771269 """Test that shutdown() properly cleans up all resources."""
11781270 vllm_config = create_vllm_config ()
11791271
1272+ scheduler = NixlConnectorScheduler (
1273+ vllm_config , vllm_config .kv_transfer_config .engine_id
1274+ )
11801275 worker = NixlConnectorWorker (vllm_config , vllm_config .kv_transfer_config .engine_id )
11811276 nixl_wrapper = worker .nixl_wrapper
11821277
11831278 with (
11841279 patch .object (worker , "_handshake_initiation_executor" ) as mock_exec ,
1185- patch .object (worker , "_nixl_handshake_listener_t" ) as mock_listener ,
1186- patch .object (worker , "_nixl_handshake_listener_stop_event" ) as mock_event ,
1280+ patch .object (scheduler , "_nixl_handshake_listener_t" ) as mock_listener ,
11871281 patch .object (nixl_wrapper , "release_xfer_handle" ) as mock_rel_xfer ,
11881282 patch .object (nixl_wrapper , "release_dlist_handle" ) as mock_rel_dlist ,
11891283 patch .object (nixl_wrapper , "remove_remote_agent" ) as mock_rem_agent ,
@@ -1204,8 +1298,12 @@ def test_shutdown_cleans_up_resources(dist_init):
12041298 worker .shutdown ()
12051299
12061300 mock_exec .shutdown .assert_called_with (wait = False )
1207- mock_event .set .assert_called_once ()
1208- mock_listener .join .assert_called_once_with (timeout = 1.0 )
1301+
1302+ # Same sequence on scheduler.shutdown()
1303+ scheduler .shutdown ()
1304+ scheduler .shutdown ()
1305+ scheduler .shutdown ()
1306+ mock_listener .join .assert_called_once ()
12091307
12101308 mock_rel_xfer .assert_called_once_with (123 )
12111309 assert mock_rel_dlist .call_count == 2
0 commit comments