diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py
index be57fee6c..8e051c808 100644
--- a/vllm_spyre/v1/worker/spyre_model_runner.py
+++ b/vllm_spyre/v1/worker/spyre_model_runner.py
@@ -308,8 +308,17 @@ def build_input_batch(self) -> SamplingInputBatch:
         # TODO(Max): logits processor list should be extensible via engine
         # constructor argument; for now the list is fixed to builtin processors
         logits_processors = get_builtin_logits_processors(self.vllm_config)
+
+        # Fix for fp8 with batch size 1 (continuous batching): size the input
+        # batch to hold at least 2 requests for the dynamic warmup, then
+        # rebuild it to hold max_num_seqs requests once warmup completes.
+        min_seqs_required = 2 if (envs_spyre.VLLM_SPYRE_USE_CB
+                                  and self.model_config.quantization
+                                  and self.warmup_mode) else 1
+
         return SamplingInputBatch(
-            max_num_reqs=self.scheduler_config.max_num_seqs,
+            max_num_reqs=max(min_seqs_required,
+                             self.scheduler_config.max_num_seqs),
             max_model_len=self.model_config.max_model_len,
             device=self.device,
             pin_memory=self.pin_memory,
@@ -841,6 +850,9 @@ def pre_warmup(self) -> None:
 
     def complete_warmup(self) -> None:
         super().complete_warmup()
+        # Fix for fp8 with batch size 1: rebuild the input_batch after warmup
+        if self.model_config.quantization:
+            self.input_batch = self.build_input_batch()
         # get the number of pages from the actual Spyre card after the warmup
         # and set it accordingly in the model runner and for the kv cache size
         n_blocks_avail = self.model.model.get_num_blocks_available()
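
For readers skimming the patch, a minimal standalone sketch of the resize-for-warmup pattern it implements follows. SimpleRunner, WARMUP_MIN_SEQS, and build_batch() are hypothetical stand-ins for the vllm_spyre classes, not the actual SpyreModelRunner API.

# A minimal standalone sketch of the pattern applied in the diff above.
# SimpleRunner, WARMUP_MIN_SEQS, and build_batch() are hypothetical
# stand-ins, not the actual SpyreModelRunner API.

WARMUP_MIN_SEQS = 2  # dynamic warmup for fp8 needs at least two requests


class SimpleRunner:

    def __init__(self, max_num_seqs: int, use_cb: bool, quantized: bool):
        self.max_num_seqs = max_num_seqs
        self.use_cb = use_cb        # stands in for VLLM_SPYRE_USE_CB
        self.quantized = quantized  # stands in for model_config.quantization
        self.warmup_mode = True
        self.batch = self.build_batch()

    def build_batch(self) -> list:
        # Oversize the batch while warming up a quantized model under
        # continuous batching, so it can hold WARMUP_MIN_SEQS requests
        # even when max_num_seqs == 1.
        min_seqs = WARMUP_MIN_SEQS if (self.use_cb and self.quantized
                                       and self.warmup_mode) else 1
        return [None] * max(min_seqs, self.max_num_seqs)

    def complete_warmup(self) -> None:
        self.warmup_mode = False
        if self.quantized:
            # Shrink back to the configured size once warmup is done.
            self.batch = self.build_batch()


runner = SimpleRunner(max_num_seqs=1, use_cb=True, quantized=True)
assert len(runner.batch) == 2   # oversized for the warmup pass
runner.complete_warmup()
assert len(runner.batch) == 1   # back to max_num_seqs afterwards

The design point the diff relies on is that build_input_batch() is the single entry point for batch sizing, so complete_warmup() can restore the configured size with one extra call once warmup_mode is no longer in effect.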