Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion vllm_spyre/v1/worker/spyre_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,17 @@ def build_input_batch(self) -> SamplingInputBatch:
# TODO(Max): logits processor list should be extensible via engine
# constructor argument; for now the list is fixed to builtin processors
logits_processors = get_builtin_logits_processors(self.vllm_config)

# fix for fp8 batch size 1 (continuous batching): size the input batch to
# fit at least 2 requests for the dynamic warmup, then rebuild the input
# batch to fit max_num_seqs requests after warmup has completed
min_seqs_required = 2 if (envs_spyre.VLLM_SPYRE_USE_CB
and self.model_config.quantization
and self.warmup_mode) else 1

return SamplingInputBatch(
max_num_reqs=self.scheduler_config.max_num_seqs,
max_num_reqs=max(min_seqs_required,
self.scheduler_config.max_num_seqs),
max_model_len=self.model_config.max_model_len,
device=self.device,
pin_memory=self.pin_memory,
Expand Down Expand Up @@ -841,6 +850,9 @@ def pre_warmup(self) -> None:

def complete_warmup(self) -> None:
super().complete_warmup()
# Fix for fp8 batch size 1: update the input_batch after warmup
if self.model_config.quantization:
self.input_batch = self.build_input_batch()
# get the number of pages from the actual Spyre card after the warmup
# and set it accordingly in the model runner and for the kv cache size
n_blocks_avail = self.model.model.get_num_blocks_available()
Expand Down