
Conversation

Contributor

@GuanLuo GuanLuo commented Oct 7, 2025

Purpose

Fix the NIXL handshake issue when a model instance spans multiple nodes due to the parallelism strategy (e.g. TP=16 running on 2 H100x8 nodes); see #25981 for details.

Test Plan

test_nixl_connector.py for unit testing; test_kv_transfer_metadata verifies that NIXLConnectorWorker properly returns its handshake metadata and retrieves the target metadata, and that NIXLConnectorScheduler can serve the collective metadata (an example invocation is sketched after this list).
nixl_integration/** for integration testing; this covers whether EngineCore properly gathers and serves the handshake metadata.
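
For reference, a hypothetical way to run the unit test mentioned above (the test path is an assumption and may not match the actual repo layout):

import pytest

# Hypothetical invocation; adjust the path to wherever test_nixl_connector.py
# lives in the checkout.
pytest.main([
    "-q",
    "tests/v1/kv_connector/unit/test_nixl_connector.py",
    "-k", "test_kv_transfer_metadata",
])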

Test Result



@GuanLuo GuanLuo changed the title fix: missing NIXL metadata for handshake initialization if instance spans multi-node [Bugfix] Missing NIXL metadata for handshake initialization if instance spans multi-node Oct 13, 2025
@mergify

mergify bot commented Oct 13, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @GuanLuo.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 13, 2025
@mergify mergify bot removed the needs-rebase label Oct 13, 2025
Collaborator

@NickLucche NickLucche left a comment


Will look into it more in depth ASAP; in the meantime, tagging other people who may be interested in reviewing: @wseaton @markmc @njhill

PS, to summarize: the listener thread is moved from worker -> scheduler. The scheduler aggregates metadata from all workers, and workers carry out the handshake by fetching data from the scheduler (single port).
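
To make that concrete, a minimal in-process sketch of the described flow (the class and function names here are hypothetical, not vLLM's actual ones, and the real exchange happens over a single scheduler-side port rather than direct calls):

from dataclasses import dataclass

@dataclass
class WorkerMeta:
    tp_rank: int
    agent_metadata: bytes  # this worker's NIXL agent blob (illustrative)

class SchedulerHandshakeServer:
    """Scheduler side: collects metadata from all TP workers (via a
    collective RPC in the real system) and serves it from one place."""
    def __init__(self) -> None:
        self._by_tp_rank: dict[int, WorkerMeta] = {}

    def aggregate(self, worker_meta: list[WorkerMeta]) -> None:
        self._by_tp_rank = {m.tp_rank: m for m in worker_meta}

    def serve(self) -> dict[int, WorkerMeta]:
        # In the real connector, this is what the single listener port returns.
        return self._by_tp_rank

def worker_handshake(remote: SchedulerHandshakeServer, tp_rank: int) -> WorkerMeta:
    # Worker side: instead of every worker running its own listener, each
    # worker fetches the matching remote worker's metadata from the remote
    # instance's scheduler.
    return remote.serve()[tp_rank]

# Usage: the remote (e.g. prefill) scheduler aggregates; local workers fetch.
remote_scheduler = SchedulerHandshakeServer()
remote_scheduler.aggregate([WorkerMeta(0, b"meta-tp0"), WorkerMeta(1, b"meta-tp1")])
print(worker_handshake(remote_scheduler, tp_rank=1).agent_metadata)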

Contributor

wseaton commented Oct 13, 2025

This does seem like a good way to get in the collective RPC and scheduler infra changes that #22274 needs, so I am supportive 🙂

@mergify

mergify bot commented Oct 13, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @GuanLuo.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 13, 2025
Signed-off-by: Guan Luo <[email protected]>
@GuanLuo GuanLuo requested a review from NickLucche October 24, 2025 05:33
Collaborator

@NickLucche NickLucche left a comment


LGTM, thanks for the hard work @GuanLuo!

Let's leave a small window for other reviewers to chime in today; otherwise we merge when CI is green again (unrelated failures should be fixed on main first).

# Need to make sure the device ID is non-negative for NIXL,
# Torch uses -1 to indicate CPU tensors while NIXL uses explicit
# memory type.
self.device_id = max(cache.get_device(), 0)
Collaborator


Out of curiosity, what happens if all caches were indeed CPU tensors (not a use case we have, but good to highlight)?

Contributor Author


If the cache is a CPU tensor, NIXL descriptor registration should still work as intended, since self.nixl_memory_type is passed to get_reg_descs, which is how NIXL differentiates whether the memory is CPU or GPU.
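
As a small illustration of that point (a simplified sketch, not the exact NIXL registration API; the "DRAM" string and the build_reg_entry helper are assumptions for illustration only):

import torch

def build_reg_entry(cache: torch.Tensor, nixl_memory_type: str):
    # torch.Tensor.get_device() returns -1 for CPU tensors; NIXL expects a
    # non-negative device id and relies on the separately passed memory type
    # to tell CPU from GPU memory.
    device_id = max(cache.get_device(), 0)
    return (cache.data_ptr(), cache.numel() * cache.element_size(),
            device_id, nixl_memory_type)

cpu_cache = torch.zeros(16)
print(build_reg_entry(cpu_cache, "DRAM"))  # device_id clamps to 0; the type still marks host memory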

Contributor

wseaton commented Oct 24, 2025

we basically "squashed" the TP dimension in this PR by grouping across TP workers, but we retain DP separation as we don't aggregate across dp rank (therefore we cleaned up all the dp references we had).

SGTM, this shouldn't impact changing handshake strategy in the future.
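
To make the grouping above concrete, a hypothetical picture of what one scheduler ends up serving (illustrative values only; the TP dimension is collapsed into a single mapping, while each DP rank serves only its own workers):

# Per-DP-rank scheduler view: tp_rank -> that worker's NIXL handshake blob.
served_by_this_dp_rank = {
    0: b"nixl-agent-meta-tp0",
    1: b"nixl-agent-meta-tp1",
    # ...one entry per TP worker of this DP rank only; no cross-DP aggregation
}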

@NickLucche
Collaborator

@GuanLuo good to merge after conflict

@GuanLuo
Contributor Author

GuanLuo commented Oct 27, 2025

@NickLucche Resolved, kicked off CI

@NickLucche NickLucche enabled auto-merge (squash) October 29, 2025 13:05
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 29, 2025
auto-merge was automatically disabled October 30, 2025 12:00

Head branch was pushed to by a user without write access

Signed-off-by: Guan Luo <[email protected]>
@GuanLuo GuanLuo requested a review from heheda12345 as a code owner October 30, 2025 12:48
@simon-mo simon-mo merged commit d6517be into vllm-project:main Oct 31, 2025
52 checks passed
xinyu-intel pushed a commit to xinyu-intel/vllm that referenced this pull request Nov 3, 2025
…ce spans multi-node (vllm-project#26338)

Signed-off-by: Guan Luo <[email protected]>
Signed-off-by: GuanLuo <[email protected]>
Signed-off-by: Guan Luo <[email protected]>
Co-authored-by: Nicolò Lucchesi <[email protected]>

Labels

documentation · kv-connector · ready · v1
