Skip to content

Commit 03df0fb

Browse files
LucasWilkinson authored and
simon-mo committed
[BugFix] Fix DP/EP hang (#25906)
Signed-off-by: Lucas Wilkinson <[email protected]> Signed-off-by: simon-mo <[email protected]>
1 parent 9471879 commit 03df0fb

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

vllm/v1/worker/gpu_model_runner.py

Lines changed: 17 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2989,13 +2989,19 @@ def _dummy_run(
29892989
# We currently only microbatch if the number of tokens is
29902990
# over a certain threshold.
29912991
if self.parallel_config.enable_dbo and allow_microbatching:
2992-
ubatch_slices, num_tokens_after_padding = ubatch_split(
2992+
ubatch_slices, ubatch_num_tokens_after_padding = ubatch_split(
29932993
num_scheduled_tokens,
29942994
total_num_scheduled_tokens,
29952995
total_num_scheduled_tokens,
29962996
uniform_decode=uniform_decode,
29972997
vllm_config=self.vllm_config,
29982998
)
2999+
# Currently when DBO is enabled `ubatch_split` returns
3000+
# the num_tokens_after_padding for a single ubatch, but we have 2 ubatches.
3001+
# TODO(sage,lucas): this is cruft that should be addressed in the
3002+
# padding refactor.
3003+
if ubatch_num_tokens_after_padding is not None:
3004+
num_tokens_after_padding = ubatch_num_tokens_after_padding * 2
29993005

30003006
# If we failed to microbatch, currently need to resynchronize
30013007
# TODO(lucas,sage): we should be able to avoid this second sync by
@@ -3112,8 +3118,9 @@ def _dummy_run(
31123118

31133119
# filter out the valid batch descriptor
31143120
_cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch(
3115-
BatchDescriptor(num_tokens=num_tokens,
3116-
uniform_decode=uniform_decode))
3121+
BatchDescriptor(num_tokens=num_tokens_after_padding,
3122+
uniform_decode=uniform_decode)) \
3123+
if not is_profile else (CUDAGraphMode.NONE, None)
31173124
if cudagraph_runtime_mode is not None:
31183125
# we allow forcing NONE when the dispatcher disagrees to support
31193126
# warm ups for cudagraph capture
@@ -3125,7 +3132,13 @@ def _dummy_run(
31253132
cudagraph_runtime_mode = _cg_mode
31263133

31273134
if ubatch_slices is not None:
3128-
num_tokens = num_tokens // 2
3135+
# Adjust values to reflect a single ubatch.
3136+
# TODO(sage,lucas): this is cruft that should be addressed in
3137+
# the padding refactor.
3138+
num_tokens_after_padding = ubatch_slices[0].num_tokens
3139+
if num_tokens_across_dp is not None:
3140+
num_tokens_across_dp[:] = num_tokens_after_padding
3141+
31293142
with self.maybe_randomize_inputs(input_ids), set_forward_context(
31303143
attn_metadata,
31313144
self.vllm_config,

0 commit comments

Comments
 (0)