Skip to content

[rollout] feat: add uid-affinity server routing for diffusion rollouts#172

Draft
SamitHuang wants to merge 8 commits into
verl-project:mainfrom
SamitHuang:feat/rollout-server-routing
Draft

[rollout] feat: add uid-affinity server routing for diffusion rollouts#172
SamitHuang wants to merge 8 commits into
verl-project:mainfrom
SamitHuang:feat/rollout-server-routing

Conversation

@SamitHuang

Copy link
Copy Markdown
Collaborator

What does this PR do?

Implement #171

Adds verl-omni multi-replica rollout server routing for async diffusion (and future omni LLM) rollouts.
Upstream verl uses GlobalRequestLoadBalancer, which routes by per-request request_id with least-inflight selection. That spreads FlowGRPO rollout.n copies of the same
prompt across replicas and makes it harder to form dense per-replica batches.
This PR introduces:

  • OmniRequestLoadBalancer — Ray actor with pluggable policies (prompt_uid_affinity, least_inflight, prompt_hash_sharding, round_robin)
  • OmniLLMServerManager / OmniLLMServerClient — verl-omni server stack that installs OmniRequestLoadBalancer instead of GlobalRequestLoadBalancer
  • RolloutServerRoutingConfig + Hydra defaults (server_routing.yaml, included from diffusion_rollout.yaml)
  • Agent-loop wiring to pass per-sample uid as routing_key
  • Diffusion default policy: prompt_uid_affinity (shared server_routing.yaml keeps least_inflight for omni LLM)
    This PR is routing-only. It does not include vllm-omni request-level diffusion batching (tracked separately).

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, vllm_omni, rollout, trainer, ci, training_utils, recipe, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data, cfg, reward, diffusion, omni, tests, docker
    • If this PR involves multiple modules, separate them with , like [diffusion, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][diffusion, fsdp] feat: new rollout scheduler

Test

For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

API and Usage Example

Demonstrate how the API changes if any, and provide usage example(s) if possible.

# Add code snippet or script demonstrating how to use this

Design & Code Changes

Demonstrate the high-level design if this PR is complex, and list the specific changes.

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

  • Read the Contribute Guide.
  • Apply pre-commit checks: pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always
  • Add / Update the documentation.
  • Add unit or end-to-end test(s) to the CI workflow to cover all the code. If not feasible, explain why: ...

Introduce RolloutServerRoutingConfig and ConfigurableRequestLoadBalancer
with policies including prompt_uid_affinity. Wire OmniLLMServerManager
into diffusion training and pass routing_key from agent loop sample kwargs.

Co-authored-by: GitHub Copilot
Seed integration tests must use OmniLLMServerManager so routing_key can
flow through the diffusion agent loop after server routing is enabled.

Co-authored-by: GitHub Copilot
routing_key_field already defaults to uid in RolloutServerRoutingConfig
and runtime OmegaConf fallbacks; keep policy as the only explicit yaml knob.

Co-authored-by: GitHub Copilot
Override the shared server_routing defaults in diffusion_rollout.yaml so
FlowGRPO-style multi-copy rollouts co-locate on one replica by default.

Co-authored-by: GitHub Copilot
Document prompt_uid_affinity routing for diffusion RL in Advanced Features,
following the async_reward doc structure and linking from the docs index.

Co-authored-by: GitHub Copilot
…uestLoadBalancer

Align the load balancer name with OmniLLMServerManager/Client and update
docstrings plus rollout server routing documentation.

Co-authored-by: GitHub Copilot
Align the advanced-features page title and intro with the PR scope:
configurable uid-affinity routing for diffusion rollouts, not new replicas.

Co-authored-by: GitHub Copilot
@SamitHuang SamitHuang requested a review from knlnguyen1802 June 14, 2026 13:29

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces uid-affinity server routing for diffusion rollouts in VeRL-Omni, allowing related rollout copies to co-locate on the same replica. It adds the OmniRequestLoadBalancer with support for multiple routing policies (such as prompt_uid_affinity and least_inflight), updates the server manager and client to support these policies, and includes corresponding configurations, documentation, and tests. The review feedback suggests preserving the original request_id in OmniLLMServerClient.generate to maintain request tracing, and caching the sorted list of server IDs in OmniRequestLoadBalancer to optimize load balancer throughput when picking servers.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +74 to +82
output: TokenOutput = await server.generate.remote(
request_id=uuid4().hex,
prompt_ids=prompt_ids,
sampling_params=sampling_params,
image_data=image_data,
video_data=video_data,
**multimodal_kwargs,
**kwargs,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The request_id passed to the generate method is ignored when calling server.generate.remote, and is instead replaced with a newly generated uuid4().hex. This breaks request tracing and correlation between the client/load balancer and the server replicas. Please forward the original request_id to the remote server call.

Suggested change
output: TokenOutput = await server.generate.remote(
request_id=uuid4().hex,
prompt_ids=prompt_ids,
sampling_params=sampling_params,
image_data=image_data,
video_data=video_data,
**multimodal_kwargs,
**kwargs,
)
output: TokenOutput = await server.generate.remote(
request_id=request_id,
prompt_ids=prompt_ids,
sampling_params=sampling_params,
image_data=image_data,
video_data=video_data,
**multimodal_kwargs,
**kwargs,
)

Comment on lines +61 to +65
self._servers: dict[str, ray.actor.ActorHandle] = dict(servers)
self._policy = policy
self._inflight_requests: dict[str, int] = {sid: 0 for sid in servers}
self._request_id_to_server: LRUCache = LRUCache(maxsize=max_cache_size)
self._round_robin_idx = 0

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To avoid sorting the server keys on every single request in _pick_sharded_server and _pick_round_robin_server, initialize and cache a sorted list of server IDs (self._sorted_server_ids) during initialization and update it when servers are added or removed.

Suggested change
self._servers: dict[str, ray.actor.ActorHandle] = dict(servers)
self._policy = policy
self._inflight_requests: dict[str, int] = {sid: 0 for sid in servers}
self._request_id_to_server: LRUCache = LRUCache(maxsize=max_cache_size)
self._round_robin_idx = 0
self._servers: dict[str, ray.actor.ActorHandle] = dict(servers)
self._sorted_server_ids = sorted(self._servers.keys())
self._policy = policy
self._inflight_requests: dict[str, int] = {sid: 0 for sid in servers}
self._request_id_to_server: LRUCache = LRUCache(maxsize=max_cache_size)
self._round_robin_idx = 0

Comment on lines +89 to +97
def _pick_sharded_server(self, sticky_key: str) -> str:
server_ids = sorted(self._servers.keys())
return server_ids[stable_shard_index(sticky_key, len(server_ids))]

def _pick_round_robin_server(self) -> str:
server_ids = sorted(self._servers.keys())
server_id = server_ids[self._round_robin_idx % len(server_ids)]
self._round_robin_idx += 1
return server_id

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Use the cached self._sorted_server_ids instead of sorting the server keys on every request to improve load balancer throughput.

Suggested change
def _pick_sharded_server(self, sticky_key: str) -> str:
server_ids = sorted(self._servers.keys())
return server_ids[stable_shard_index(sticky_key, len(server_ids))]
def _pick_round_robin_server(self) -> str:
server_ids = sorted(self._servers.keys())
server_id = server_ids[self._round_robin_idx % len(server_ids)]
self._round_robin_idx += 1
return server_id
def _pick_sharded_server(self, sticky_key: str) -> str:
return self._sorted_server_ids[stable_shard_index(sticky_key, len(self._sorted_server_ids))]
def _pick_round_robin_server(self) -> str:
server_id = self._sorted_server_ids[self._round_robin_idx % len(self._sorted_server_ids)]
self._round_robin_idx += 1
return server_id

Comment on lines +121 to +129
def add_servers(self, servers: dict[str, ray.actor.ActorHandle]) -> None:
for sid, handle in servers.items():
self._inflight_requests[sid] = 0
self._servers[sid] = handle

def remove_servers(self, server_ids: list[str]) -> None:
for sid in server_ids:
self._inflight_requests.pop(sid, None)
self._servers.pop(sid, None)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Update the cached self._sorted_server_ids when servers are dynamically added or removed to ensure the list remains accurate.

Suggested change
def add_servers(self, servers: dict[str, ray.actor.ActorHandle]) -> None:
for sid, handle in servers.items():
self._inflight_requests[sid] = 0
self._servers[sid] = handle
def remove_servers(self, server_ids: list[str]) -> None:
for sid in server_ids:
self._inflight_requests.pop(sid, None)
self._servers.pop(sid, None)
def add_servers(self, servers: dict[str, ray.actor.ActorHandle]) -> None:
for sid, handle in servers.items():
self._inflight_requests[sid] = 0
self._servers[sid] = handle
self._sorted_server_ids = sorted(self._servers.keys())
def remove_servers(self, server_ids: list[str]) -> None:
for sid in server_ids:
self._inflight_requests.pop(sid, None)
self._servers.pop(sid, None)
self._sorted_server_ids = sorted(self._servers.keys())

@knlnguyen1802

Copy link
Copy Markdown
Collaborator

Test this PR on top of continuous batching
Result on main branch
timing_gen_trend
Result on this branch
timing_gen_trend_route

@SamitHuang

Copy link
Copy Markdown
Collaborator Author

how did you set the max_seq_num?

@knlnguyen1802

knlnguyen1802 commented Jun 16, 2026

Copy link
Copy Markdown
Collaborator

default value is 1024 in config, and I think it's large enough
and it can be set in rollout config too

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants