@@ -2989,13 +2989,19 @@ def _dummy_run(
2989
2989
# We currently only microbatch if the number of tokens is
2990
2990
# over a certain threshold.
2991
2991
if self.parallel_config.enable_dbo and allow_microbatching:
2992
- ubatch_slices, num_tokens_after_padding = ubatch_split(
2992
+ ubatch_slices, ubatch_num_tokens_after_padding = ubatch_split(
2993
2993
num_scheduled_tokens,
2994
2994
total_num_scheduled_tokens,
2995
2995
total_num_scheduled_tokens,
2996
2996
uniform_decode=uniform_decode,
2997
2997
vllm_config=self.vllm_config,
2998
2998
)
2999
+ # Currently when DBO is enabled `ubatch_split` returns
3000
+ # the num_tokens_after_padding for a single ubatch, but we have 2
3001
+ # TODO(sage,lucas): this is cruft that should be addressed in the
3002
+ # padding refactor.
3003
+ if ubatch_num_tokens_after_padding is not None:
3004
+ num_tokens_after_padding = ubatch_num_tokens_after_padding * 2
2999
3005
3000
3006
# If we failed to microbatch, currently need to resynchronize
3001
3007
# TODO(lucas,sage): we should be able to avoid this second sync by
@@ -3112,8 +3118,9 @@ def _dummy_run(
3112
3118
3113
3119
# filter out the valid batch descriptor
3114
3120
_cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch(
3115
- BatchDescriptor(num_tokens=num_tokens,
3116
- uniform_decode=uniform_decode))
3121
+ BatchDescriptor(num_tokens=num_tokens_after_padding,
3122
+ uniform_decode=uniform_decode)) \
3123
+ if not is_profile else (CUDAGraphMode.NONE, None)
3117
3124
if cudagraph_runtime_mode is not None :
3118
3125
# we allow forcing NONE when the dispatcher disagrees to support
3119
3126
# warm ups for cudagraph capture
@@ -3125,7 +3132,13 @@ def _dummy_run(
3125
3132
cudagraph_runtime_mode = _cg_mode
3126
3133
3127
3134
if ubatch_slices is not None:
3128
- num_tokens = num_tokens // 2
3135
+ # Adjust values to reflect a single ubatch.
3136
+ # TODO(sage,lucas): this is cruft that should be addressed in
3137
+ # the padding refactor.
3138
+ num_tokens_after_padding = ubatch_slices[0].num_tokens
3139
+ if num_tokens_across_dp is not None:
3140
+ num_tokens_across_dp[:] = num_tokens_after_padding
3141
+
3129
3142
with self.maybe_randomize_inputs(input_ids), set_forward_context(
3130
3143
attn_metadata ,
3131
3144
self.vllm_config,
0 commit comments