[rollout] feat: add uid-affinity server routing for diffusion rollouts #172
SamitHuang wants to merge 8 commits into
Introduce RolloutServerRoutingConfig and ConfigurableRequestLoadBalancer with policies including prompt_uid_affinity. Wire OmniLLMServerManager into diffusion training and pass routing_key from agent loop sample kwargs. Co-authored-by: GitHub Copilot
Seed integration tests must use OmniLLMServerManager so routing_key can flow through the diffusion agent loop after server routing is enabled. Co-authored-by: GitHub Copilot
routing_key_field already defaults to uid in RolloutServerRoutingConfig and runtime OmegaConf fallbacks; keep policy as the only explicit yaml knob. Co-authored-by: GitHub Copilot
Override the shared server_routing defaults in diffusion_rollout.yaml so FlowGRPO-style multi-copy rollouts co-locate on one replica by default. Co-authored-by: GitHub Copilot
Document prompt_uid_affinity routing for diffusion RL in Advanced Features, following the async_reward doc structure and linking from the docs index. Co-authored-by: GitHub Copilot
…uestLoadBalancer Align the load balancer name with OmniLLMServerManager/Client and update docstrings plus rollout server routing documentation. Co-authored-by: GitHub Copilot
Co-authored-by: GitHub Copilot
Align the advanced-features page title and intro with the PR scope: configurable uid-affinity routing for diffusion rollouts, not new replicas. Co-authored-by: GitHub Copilot
Code Review
This pull request introduces uid-affinity server routing for diffusion rollouts in VeRL-Omni, allowing related rollout copies to co-locate on the same replica. It adds the OmniRequestLoadBalancer with support for multiple routing policies (such as prompt_uid_affinity and least_inflight), updates the server manager and client to support these policies, and includes corresponding configurations, documentation, and tests. The review feedback suggests preserving the original request_id in OmniLLMServerClient.generate to maintain request tracing, and caching the sorted list of server IDs in OmniRequestLoadBalancer to optimize load balancer throughput when picking servers.
    output: TokenOutput = await server.generate.remote(
        request_id=uuid4().hex,
        prompt_ids=prompt_ids,
        sampling_params=sampling_params,
        image_data=image_data,
        video_data=video_data,
        **multimodal_kwargs,
        **kwargs,
    )
The request_id passed to the generate method is ignored when calling server.generate.remote, and is instead replaced with a newly generated uuid4().hex. This breaks request tracing and correlation between the client/load balancer and the server replicas. Please forward the original request_id to the remote server call.
Suggested change:

     output: TokenOutput = await server.generate.remote(
    -    request_id=uuid4().hex,
    +    request_id=request_id,
         prompt_ids=prompt_ids,
         sampling_params=sampling_params,
         image_data=image_data,
         video_data=video_data,
         **multimodal_kwargs,
         **kwargs,
     )
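The tracing concern in the review comment above can be shown with a small stand-alone sketch. It is not the OmniLLMServerClient code: FakeServer, generate_fresh_id, and generate_forwarded are hypothetical names, and a plain asyncio coroutine stands in for the Ray remote call.

```python
import asyncio
from uuid import uuid4


class FakeServer:
    """Hypothetical stand-in for a rollout replica; records the ids it receives."""

    def __init__(self) -> None:
        self.seen: list[str] = []

    async def generate(self, request_id: str, prompt_ids: list[int]) -> list[int]:
        self.seen.append(request_id)
        return prompt_ids  # echo; a real server would return generated tokens


async def generate_fresh_id(server: FakeServer, request_id: str, prompt_ids: list[int]) -> list[int]:
    # Pattern flagged in the review: the caller's request_id is dropped.
    return await server.generate(request_id=uuid4().hex, prompt_ids=prompt_ids)


async def generate_forwarded(server: FakeServer, request_id: str, prompt_ids: list[int]) -> list[int]:
    # Suggested pattern: the caller's request_id reaches the server unchanged.
    return await server.generate(request_id=request_id, prompt_ids=prompt_ids)


async def main() -> tuple[list[str], list[str]]:
    fresh_server, forwarding_server = FakeServer(), FakeServer()
    await generate_fresh_id(fresh_server, "req-1", [1, 2, 3])
    await generate_forwarded(forwarding_server, "req-1", [1, 2, 3])
    return fresh_server.seen, forwarding_server.seen


fresh_seen, forwarded_seen = asyncio.run(main())
```

In this sketch only forwarded_seen contains "req-1", so only the forwarding variant lets a server-side log line be matched back to the id the load balancer saw.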
    self._servers: dict[str, ray.actor.ActorHandle] = dict(servers)
    self._policy = policy
    self._inflight_requests: dict[str, int] = {sid: 0 for sid in servers}
    self._request_id_to_server: LRUCache = LRUCache(maxsize=max_cache_size)
    self._round_robin_idx = 0
To avoid sorting the server keys on every single request in _pick_sharded_server and _pick_round_robin_server, initialize and cache a sorted list of server IDs (self._sorted_server_ids) during initialization and update it when servers are added or removed.
Suggested change:

     self._servers: dict[str, ray.actor.ActorHandle] = dict(servers)
    +self._sorted_server_ids = sorted(self._servers.keys())
     self._policy = policy
     self._inflight_requests: dict[str, int] = {sid: 0 for sid in servers}
     self._request_id_to_server: LRUCache = LRUCache(maxsize=max_cache_size)
     self._round_robin_idx = 0
    def _pick_sharded_server(self, sticky_key: str) -> str:
        server_ids = sorted(self._servers.keys())
        return server_ids[stable_shard_index(sticky_key, len(server_ids))]

    def _pick_round_robin_server(self) -> str:
        server_ids = sorted(self._servers.keys())
        server_id = server_ids[self._round_robin_idx % len(server_ids)]
        self._round_robin_idx += 1
        return server_id
Use the cached self._sorted_server_ids instead of sorting the server keys on every request to improve load balancer throughput.
Suggested change:

     def _pick_sharded_server(self, sticky_key: str) -> str:
    -    server_ids = sorted(self._servers.keys())
    -    return server_ids[stable_shard_index(sticky_key, len(server_ids))]
    +    return self._sorted_server_ids[stable_shard_index(sticky_key, len(self._sorted_server_ids))]

     def _pick_round_robin_server(self) -> str:
    -    server_ids = sorted(self._servers.keys())
    -    server_id = server_ids[self._round_robin_idx % len(server_ids)]
    +    server_id = self._sorted_server_ids[self._round_robin_idx % len(self._sorted_server_ids)]
         self._round_robin_idx += 1
         return server_id
    def add_servers(self, servers: dict[str, ray.actor.ActorHandle]) -> None:
        for sid, handle in servers.items():
            self._inflight_requests[sid] = 0
            self._servers[sid] = handle

    def remove_servers(self, server_ids: list[str]) -> None:
        for sid in server_ids:
            self._inflight_requests.pop(sid, None)
            self._servers.pop(sid, None)
Update the cached self._sorted_server_ids when servers are dynamically added or removed to ensure the list remains accurate.
Suggested change:

     def add_servers(self, servers: dict[str, ray.actor.ActorHandle]) -> None:
         for sid, handle in servers.items():
             self._inflight_requests[sid] = 0
             self._servers[sid] = handle
    +    self._sorted_server_ids = sorted(self._servers.keys())

     def remove_servers(self, server_ids: list[str]) -> None:
         for sid in server_ids:
             self._inflight_requests.pop(sid, None)
             self._servers.pop(sid, None)
    +    self._sorted_server_ids = sorted(self._servers.keys())
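The caching idea from the three review comments above can be put together in one stand-alone sketch. It is an illustration under assumptions, not the OmniRequestLoadBalancer implementation: CachedIdBalancer is a hypothetical class with no Ray dependency, server handles are plain objects, and the stable_shard_index shown here is a SHA-256 stand-in for the PR's helper, whose actual definition is not shown on this page.

```python
import hashlib


def stable_shard_index(key: str, num_shards: int) -> int:
    """Assumed stand-in for the PR's helper: hash a key into [0, num_shards)."""
    digest = hashlib.sha256(key.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big") % num_shards


class CachedIdBalancer:
    """Hypothetical balancer that keeps a sorted id list in sync with the server map."""

    def __init__(self, servers: dict[str, object]) -> None:
        self._servers = dict(servers)
        self._round_robin_idx = 0
        self._refresh_ids()

    def _refresh_ids(self) -> None:
        # Single place that rebuilds the cache, so add/remove cannot skip it.
        self._sorted_server_ids = sorted(self._servers)

    def add_servers(self, servers: dict[str, object]) -> None:
        self._servers.update(servers)
        self._refresh_ids()

    def remove_servers(self, server_ids: list[str]) -> None:
        for sid in server_ids:
            self._servers.pop(sid, None)
        self._refresh_ids()

    def pick_sharded(self, sticky_key: str) -> str:
        ids = self._sorted_server_ids
        if not ids:
            raise RuntimeError("no servers registered")
        return ids[stable_shard_index(sticky_key, len(ids))]

    def pick_round_robin(self) -> str:
        ids = self._sorted_server_ids
        if not ids:
            raise RuntimeError("no servers registered")
        server_id = ids[self._round_robin_idx % len(ids)]
        self._round_robin_idx += 1
        return server_id


lb = CachedIdBalancer({"b": object(), "a": object()})
first = lb.pick_sharded("uid-7")  # same key, same replica while membership is unchanged
lb.remove_servers(["a"])
after = lb.pick_sharded("uid-7")  # only "b" is left, so every key maps to it
```

Because the shard index is taken modulo the number of servers, adding or removing a replica can move an existing key to a different replica; the cache only removes the per-request sort, it does not change that behavior.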
how did you set the max_seq_num?
default value is 1024 in config, and I think it's large enough
What does this PR do?
Implement #171
Adds verl-omni multi-replica rollout server routing for async diffusion (and future omni LLM) rollouts.
Upstream verl uses GlobalRequestLoadBalancer, which routes by per-request request_id with least-inflight selection. That spreads FlowGRPO rollout.n copies of the same prompt across replicas and makes it harder to form dense per-replica batches.
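The spreading effect described above can be reproduced with a small simulation. This is a sketch, not verl or verl-omni code: pick_least_inflight, pick_uid_affinity, and route_copies are hypothetical helpers, and hashing the uid with SHA-256 is an assumption about how a stable uid-to-replica mapping might be built.

```python
import hashlib


def pick_least_inflight(inflight: dict[str, int]) -> str:
    """Pick the replica with the fewest in-flight requests (ties broken by id order)."""
    return min(sorted(inflight), key=lambda sid: inflight[sid])


def pick_uid_affinity(uid: str, server_ids: list[str]) -> str:
    """Map a prompt uid to one replica with a process-independent hash."""
    digest = hashlib.sha256(uid.encode("utf-8")).digest()
    return server_ids[int.from_bytes(digest[:8], "big") % len(server_ids)]


def route_copies(uid: str, n: int, server_ids: list[str], policy: str) -> list[str]:
    """Route n rollout copies of one prompt; return the chosen replica per copy."""
    inflight = {sid: 0 for sid in server_ids}
    ordered_ids = sorted(server_ids)
    chosen = []
    for _ in range(n):
        if policy == "least_inflight":
            sid = pick_least_inflight(inflight)
        else:  # treated as "prompt_uid_affinity" in this sketch
            sid = pick_uid_affinity(uid, ordered_ids)
        inflight[sid] += 1  # the copy is still in flight when the next one arrives
        chosen.append(sid)
    return chosen


servers = ["replica-0", "replica-1"]
spread = route_copies("prompt-42", 4, servers, "least_inflight")
sticky = route_copies("prompt-42", 4, servers, "prompt_uid_affinity")
```

With two replicas and four copies, spread alternates between the replicas, while sticky names a single replica for all four copies, which is the co-location that makes a dense per-replica batch possible.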
This PR introduces:
- OmniRequestLoadBalancer: Ray actor with pluggable policies (prompt_uid_affinity, least_inflight, prompt_hash_sharding, round_robin)
- OmniLLMServerManager / OmniLLMServerClient: verl-omni server stack that installs OmniRequestLoadBalancer instead of GlobalRequestLoadBalancer
- RolloutServerRoutingConfig + Hydra defaults (server_routing.yaml, included from diffusion_rollout.yaml)
- uid as routing_key
- prompt_uid_affinity (shared server_routing.yaml keeps least_inflight for omni LLM)

This PR is routing-only. It does not include vllm-omni request-level diffusion batching (tracked separately).
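As a rough picture of how the config and the routing key could fit together, here is a minimal sketch. RoutingConfigSketch and extract_routing_key are hypothetical names; only the policy names, the uid default for routing_key_field, and the least_inflight shared default come from this PR page, and the real RolloutServerRoutingConfig may have other fields.

```python
from dataclasses import dataclass
from typing import Any, Optional

# Policy names as listed in the PR description.
SUPPORTED_POLICIES = (
    "prompt_uid_affinity",
    "least_inflight",
    "prompt_hash_sharding",
    "round_robin",
)


@dataclass
class RoutingConfigSketch:
    """Illustrative stand-in for RolloutServerRoutingConfig; the field set is assumed."""

    policy: str = "least_inflight"  # shared default; diffusion overrides it
    routing_key_field: str = "uid"  # which sample kwarg carries the routing key

    def __post_init__(self) -> None:
        if self.policy not in SUPPORTED_POLICIES:
            raise ValueError(f"unknown routing policy: {self.policy!r}")


def extract_routing_key(sample_kwargs: dict[str, Any], cfg: RoutingConfigSketch) -> Optional[str]:
    """Read the routing key from per-sample kwargs; None means no affinity hint."""
    value = sample_kwargs.get(cfg.routing_key_field)
    return None if value is None else str(value)


diffusion_cfg = RoutingConfigSketch(policy="prompt_uid_affinity")
key = extract_routing_key({"uid": "prompt-42", "index": 3}, diffusion_cfg)
missing = extract_routing_key({"index": 3}, diffusion_cfg)
```

Returning None when the field is absent leaves the fallback behavior to the caller; how the actual load balancer treats a missing routing_key is not stated on this page.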
Checklist Before Starting
- [{modules}] {type}: {description} (This will be checked by the CI)
- {modules} include fsdp, vllm_omni, rollout, trainer, ci, training_utils, recipe, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data, cfg, reward, diffusion, omni, tests, docker, like [diffusion, doc]
- {type} is in feat, fix, refactor, chore, test
- [BREAKING] to the beginning of the title.
- [BREAKING][diffusion, fsdp] feat: new rollout scheduler
Test
API and Usage Example
# Add code snippet or script demonstrating how to use this
Design & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always