
Conversation

wseaton
Contributor

@wseaton wseaton commented Aug 5, 2025

Purpose

Adds support for HTTP-based metadata interchange between prefill and decode instances of vLLM in a "north to south" fashion, using a dedicated API server endpoint as the side channel (a completely new entrypoint). This new entrypoint is spawned dynamically based on the KVConnector configuration and the value of the new VLLM_NIXL_HANDSHAKE_METHOD environment variable.

This PR is a manual rebase of work originally started in #19447. The amount of changes and merge commits made it difficult to rebase incrementally, so I've reimplemented the PR.

During initial review of #19447, it was rightly pointed out by @russellb that this new side-channel route needs additional protections and should not be part of the existing user-facing API server. It has been factored out into the new entrypoint setup.

New features:

  • Adds two new environment variables to envs.py:
    • VLLM_NIXL_HANDSHAKE_METHOD: "zmq" or "http"; defaults to "zmq" to preserve legacy behaviour
    • VLLM_NIXL_HANDSHAKE_TIMEOUT: an HTTP timeout so the handshake doesn't block indefinitely. The handshake runs in a background thread, but we don't want to wait forever for it to complete.
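
As a sketch of how env-var-driven selection might look (hypothetical helper; only the two variable names come from this PR, while the function name, validation, and the 30-second fallback are illustrative, not vLLM's actual code):

```python
import os

# Hypothetical helper sketching env-var-driven handshake selection.
# Only the two variable names come from this PR; everything else is invented.
def select_handshake_method(env=None):
    env = os.environ if env is None else env
    method = env.get("VLLM_NIXL_HANDSHAKE_METHOD", "zmq").lower()
    if method not in ("zmq", "http"):
        raise ValueError(f"unsupported handshake method: {method}")
    # Illustrative 30s fallback; the PR's actual default is not shown here.
    timeout = float(env.get("VLLM_NIXL_HANDSHAKE_TIMEOUT", "30"))
    return method, timeout
```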

Moved to another PR:

  • Support for TLS configuration of the side channel. Uses the same TLS config as the server itself, supporting a pod-level SSL termination setup as would be desirable in Kubernetes (k8s).

Test Plan

Using the pd_examples Justfile here (https://github.com/wseaton/pd_examples/blob/dp-experiments/Justfile#L41-L126), install vLLM in a venv with nixl==0.5.1.

Spin up four tmux panes with various decode TP sizes:

  • just prefill
  • just decode
  • just proxy
  • just send_request
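
Before running the just targets above, the new env vars can be exported in each pane to opt into the HTTP handshake; the timeout value below is arbitrary:

```shell
# Opt in to the HTTP side channel (default is "zmq")
export VLLM_NIXL_HANDSHAKE_METHOD=http
# Don't block forever waiting on the handshake (seconds; value is illustrative)
export VLLM_NIXL_HANDSHAKE_TIMEOUT=30
```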


github-actions bot commented Aug 5, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs will not trigger a full CI run by default. Instead, only fastcheck CI will run, which starts with a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀


mergify bot commented Aug 5, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @wseaton.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Aug 5, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for HTTP/S-based metadata exchange for the NixlConnector as an alternative to the existing ZMQ implementation. This is a significant architectural improvement that enhances security and simplifies deployment in containerized environments like Kubernetes.

The introduction of a dedicated SSLConfig class is a great step towards better configuration management, and the use of the strategy pattern for handshake mechanisms (ZmqHandshakeStrategy and HttpHandshakeStrategy) makes the code more modular and extensible.

I've identified a couple of high-severity issues that should be addressed before merging. One is related to error handling during server startup, and the other concerns a potential logic flaw in the ZMQ handshake strategy. Please see the detailed comments below.

Overall, this is a well-structured and valuable contribution. Great work!

Comment on lines 152 to 170
# Handshake with remote agent-rank0 first to get the tp_size of remote
path = make_zmq_path("tcp", host, port)
logger.debug("Querying master rank metadata on path: %s", path)
metadata, agent_name_0 = handshake(path, 0)

agents = {0: agent_name_0}

# Handshake only with the other TP remote the current local rank will
# pull from. With homogeneous TP it happens to be the same rank_i.
tp_ratio = self._tp_size[self.engine_id] // remote_tp_size
p_remote_rank = self.tp_rank // tp_ratio
if p_remote_rank > 0:
    path = make_zmq_path("tcp", host, port + p_remote_rank)
    logger.debug("Querying metadata on path: %s at remote rank %s",
                 path, p_remote_rank)
    _, agent_name = handshake(path, p_remote_rank)
    agents[p_remote_rank] = agent_name

return agents
Contributor


critical

The handshake logic in ZmqHandshakeStrategy appears to be unnecessarily complex and potentially incorrect. It always performs a handshake with remote rank 0, and then another with the calculated p_remote_rank if it's not 0. This is inconsistent with the HttpHandshakeStrategy and the previous implementation, which only interact with the required p_remote_rank.

This extra handshake is inefficient and could lead to performance issues. The comment on line 152 is also misleading, as remote_tp_size is already provided as an argument and not derived from the handshake.

The logic should be simplified to perform a single handshake with the target p_remote_rank.

Suggested change
-# Handshake with remote agent-rank0 first to get the tp_size of remote
-path = make_zmq_path("tcp", host, port)
-logger.debug("Querying master rank metadata on path: %s", path)
-metadata, agent_name_0 = handshake(path, 0)
-agents = {0: agent_name_0}
-# Handshake only with the other TP remote the current local rank will
-# pull from. With homogeneous TP it happens to be the same rank_i.
-tp_ratio = self._tp_size[self.engine_id] // remote_tp_size
-p_remote_rank = self.tp_rank // tp_ratio
-if p_remote_rank > 0:
-    path = make_zmq_path("tcp", host, port + p_remote_rank)
-    logger.debug("Querying metadata on path: %s at remote rank %s",
-                 path, p_remote_rank)
-    _, agent_name = handshake(path, p_remote_rank)
-    agents[p_remote_rank] = agent_name
-return agents
+# Handshake only with the remote TP rank that the current local rank will
+# pull from. With homogeneous TP, it happens to be the same rank_i.
+tp_ratio = self._tp_size[self.engine_id] // remote_tp_size
+p_remote_rank = self.tp_rank // tp_ratio
+path = make_zmq_path("tcp", host, port + p_remote_rank)
+logger.debug("Querying metadata on path: %s at remote rank %s",
+             path, p_remote_rank)
+_, agent_name = handshake(path, p_remote_rank)
+return {p_remote_rank: agent_name}
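
For intuition, the tp_ratio / p_remote_rank arithmetic shared by both variants can be sketched in isolation (the function name is made up; it assumes the decode TP size is a multiple of the prefill TP size):

```python
def remote_rank_to_pull_from(local_tp_size: int, remote_tp_size: int,
                             local_rank: int) -> int:
    """Map a local (decode) TP rank to the remote (prefill) rank it pulls from."""
    assert local_tp_size % remote_tp_size == 0, \
        "local TP size must be a multiple of remote TP size"
    tp_ratio = local_tp_size // remote_tp_size
    return local_rank // tp_ratio
```

With homogeneous TP, tp_ratio is 1 and each decode rank pulls from the same-numbered prefill rank.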

@wseaton wseaton force-pushed the kv-connector-overhaul branch from 1042cb7 to cb7146b Compare August 5, 2025 21:02
@mergify mergify bot removed the needs-rebase label Aug 5, 2025
@wseaton
Contributor Author

wseaton commented Aug 5, 2025

cc @NickLucche @njhill

Collaborator

@NickLucche NickLucche left a comment


This is a big one; I will need another pass.
If we can find things to factor out into separate independent PRs, it would be awesome, e.g. SSL?

@wseaton wseaton force-pushed the kv-connector-overhaul branch from 3dd9f79 to c3c4b06 Compare August 7, 2025 00:23
@wseaton
Contributor Author

wseaton commented Aug 7, 2025

Hmm, looks like somewhere along the way we lost heterogeneous TP (het TP) correctness; trying to figure out what went wrong there.

Edit: main with the bugfix patch from 80b7562 doesn't have it either, so we might want to ignore it for the sake of this PR:

diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py
index 1da41790f..a4efc8664 100644
--- a/vllm/distributed/kv_transfer/kv_connector/utils.py
+++ b/vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -142,12 +142,13 @@ class KVOutputAggregator:
         finished_sending = set[str]()
         finished_recving = set[str]()
         for output in outputs:
-            output = output.kv_connector_output
-            update_finished_set(output.finished_sending,
-                                self._send_remaining_count, finished_sending)
-            update_finished_set(output.finished_recving,
-                                self._recv_remaining_count, finished_recving)
-
+            if (kv_output := output.kv_connector_output) is not None:
+                update_finished_set(kv_output.finished_sending,
+                                    self._send_remaining_count,
+                                    finished_sending)
+                update_finished_set(kv_output.finished_recving,
+                                    self._recv_remaining_count,
+                                    finished_recving)
         # select output of the worker specified by output_rank
         output = outputs[output_rank]

@wseaton wseaton changed the title Draft: NixlConnector Support HTTP/S metadata exchange instead of zmq NixlConnector Support HTTP/S metadata exchange instead of zmq Aug 7, 2025

mergify bot commented Aug 8, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @wseaton.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Aug 8, 2025
@NickLucche
Collaborator

NickLucche commented Aug 9, 2025

looks like somewhere along the way we lost het TP correctness

Sorry I lost this one, is hetTP broken somewhere I can look at?


mergify bot commented Sep 4, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @wseaton.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork


mergify bot commented Sep 19, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @wseaton.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 19, 2025
@wseaton wseaton force-pushed the kv-connector-overhaul branch from 1ee2d36 to f0c3298 Compare September 22, 2025 12:40
@wseaton wseaton requested a review from ApostaC as a code owner September 22, 2025 12:40
@wseaton wseaton force-pushed the kv-connector-overhaul branch from f0c3298 to 94414db Compare September 22, 2025 12:43
wseaton and others added 4 commits September 22, 2025 08:51
Signed-off-by: Will Eaton <[email protected]>
Co-authored-by: NickLucche <[email protected]>
Signed-off-by: Will Eaton <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Signed-off-by: Will Eaton <[email protected]>
make timeout longer; implement new shutdown api

Signed-off-by: Will Eaton <[email protected]>
@wseaton wseaton force-pushed the kv-connector-overhaul branch from 94414db to ea721c8 Compare September 22, 2025 12:57
@mergify mergify bot removed the needs-rebase label Sep 22, 2025
Signed-off-by: Will Eaton <[email protected]>
Signed-off-by: Will Eaton <[email protected]>

mergify bot commented Sep 22, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @wseaton.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 22, 2025
@GuanLuo
Contributor

GuanLuo commented Sep 27, 2025

Hi, I took a look at this PR as I was told it may resolve the issue of incorrect NIXL handshakes when model parallelism spans multiple nodes (#19080 tried to address it but didn't get to the finish line).

This PR looks promising and I just want to summarize the exchange logic to verify my understanding:

  1. Regardless of the parallelism strategy, the collective RPC get_kv_connector_handshake_metadata will collect the NIXL metadata from each of the workers (even across cluster nodes) and return it via the engine client's get_kv_handshake_metadata.
  2. Once the collective metadata is obtained, on the head node, an HTTP server listening on SIDE_CHANNEL_HOST/PORT will be spawned to serve the metadata. SIDE_CHANNEL_HOST/PORT will be encapsulated in the prefill response as before, and the decode engine will use that to reach the HTTP server.
  3. Standard procedure: each worker in the decode engine will reach the HTTP server for the metadata of the desired prefill workers and perform the handshake / block transfer.
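
A toy, self-contained model of steps 2–3 above (the payload shape, port handling, and function names are all invented for illustration; this is not vLLM's actual API):

```python
import json
import threading
import urllib.request
from http.server import BaseHTTPRequestHandler, HTTPServer

# Stand-in for the collected NIXL metadata, keyed by prefill rank (invented shape).
METADATA = {"0": {"agent_name": "prefill-rank0"}, "1": {"agent_name": "prefill-rank1"}}

class MetadataHandler(BaseHTTPRequestHandler):
    def do_GET(self):
        # Head node serves the full collected metadata as JSON.
        body = json.dumps(METADATA).encode()
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, *args):  # keep the demo quiet
        pass

def serve_metadata():
    # Port 0 lets the OS pick a free port, like a configurable SIDE_CHANNEL_PORT.
    server = HTTPServer(("127.0.0.1", 0), MetadataHandler)
    threading.Thread(target=server.serve_forever, daemon=True).start()
    return server

def fetch_remote_agent(port, remote_rank, timeout=5.0):
    # Decode worker fetches the metadata and picks out its target prefill rank.
    with urllib.request.urlopen(f"http://127.0.0.1:{port}/", timeout=timeout) as resp:
        return json.loads(resp.read())[str(remote_rank)]["agent_name"]
```

In the real PR, the served payload would be the collected NIXL handshake metadata and the client would be a decode worker using the SIDE_CHANNEL_HOST/PORT carried in the prefill response.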

@wseaton
Contributor Author

wseaton commented Oct 7, 2025

I will pick this back up in the next day or two and break it into smaller pieces, starting with the handshake abstractions for zmq.

@wseaton
Contributor Author

wseaton commented Oct 13, 2025

Closing for now, we have a path forward in other PRs!

@wseaton wseaton closed this Oct 13, 2025