From 80361bd7b08427680067bc297202213da233c177 Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Thu, 9 Oct 2025 21:18:44 -0300 Subject: [PATCH 01/16] feat: improve tests Signed-off-by: Wallas Santos --- tests/e2e/test_sampling_params.py | 17 +-- .../v1/sample/spyre_logits_processor.py | 133 ++++++++++++++++-- vllm_spyre/v1/worker/spyre_input_batch.py | 6 +- vllm_spyre/v1/worker/spyre_model_runner.py | 2 +- 4 files changed, 137 insertions(+), 21 deletions(-) diff --git a/tests/e2e/test_sampling_params.py b/tests/e2e/test_sampling_params.py index 7f2e6d3f..ac34dcf4 100644 --- a/tests/e2e/test_sampling_params.py +++ b/tests/e2e/test_sampling_params.py @@ -249,15 +249,17 @@ def test_spyre_batch1_logit_bias(model: ModelInfo, backend, monkeypatch, def test_spyre_batch1_min_tokens(model: ModelInfo, backend, monkeypatch, - use_llm_cache, warmup_shapes): + use_llm_cache, max_model_len, max_num_seqs, + warmup_shapes, cb: int): spyre_model = get_cached_llm( model=model, - max_model_len=128, + max_model_len=max_model_len, tensor_parallel_size=1, backend=backend, monkeypatch=monkeypatch, - warmup_shapes=warmup_shapes, - ) + warmup_shapes=warmup_shapes if cb != 1 else None, + max_num_seqs=max_num_seqs if cb == 1 else None, + use_cb=cb==1) prompt = "What is the capital of the USA?" tokenizer = spyre_model.get_tokenizer() eos_id = tokenizer.eos_token_id @@ -268,11 +270,10 @@ def test_spyre_batch1_min_tokens(model: ModelInfo, backend, monkeypatch, max_tokens=20) params2 = SamplingParams(seed=8780, logit_bias={eos_id: 50}, max_tokens=20) - output1 = spyre_model.generate(prompt, params1)[0] - output2 = spyre_model.generate(prompt, params2)[0] + output = spyre_model.generate([prompt] * 2, [params1, params2]) - assert len(output1.outputs[0].token_ids) >= 19 - assert len(output2.outputs[0].token_ids) < 19 + assert len(output[0].outputs[0].token_ids) >= 19 + assert len(output[1].outputs[0].token_ids) < 19 def test_spyre_batch1_ignore_eos(model: ModelInfo, backend, monkeypatch, diff --git a/vllm_spyre/v1/sample/spyre_logits_processor.py b/vllm_spyre/v1/sample/spyre_logits_processor.py index ec6621f7..f445355d 100644 --- a/vllm_spyre/v1/sample/spyre_logits_processor.py +++ b/vllm_spyre/v1/sample/spyre_logits_processor.py @@ -1,20 +1,29 @@ import itertools -from typing import Optional, Sequence, Union +from typing import Optional, Sequence, Union, TYPE_CHECKING import torch -from vllm.config import VllmConfig + from vllm.logger import init_logger -from vllm.v1.sample.logits_processor import (BUILTIN_LOGITS_PROCESSORS, - STR_POOLING_REJECTS_LOGITSPROCS, - BatchUpdate, LogitsProcessor, - _load_custom_logitsprocs) +from vllm.v1.sample.logits_processor import ( + BUILTIN_LOGITS_PROCESSORS, STR_POOLING_REJECTS_LOGITSPROCS, BatchUpdate, + LogitsProcessor, MinPLogitsProcessor, LogitBiasLogitsProcessor, + MinTokensLogitsProcessor, _load_custom_logitsprocs) from vllm.v1.sample.logits_processor.state import LogitsProcessors logger = init_logger(__name__) +if TYPE_CHECKING: + from vllm import SamplingParams + from vllm.config import VllmConfig +else: + SamplingParams = None + VllmConfig = None + +assert len(BUILTIN_LOGITS_PROCESSORS) == 3 + def build_logitsprocs_for_cb( - vllm_config: "VllmConfig", + vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool, is_pooling_model: bool, @@ -30,7 +39,7 @@ def build_logitsprocs_for_cb( custom_logitsprocs_classes = _load_custom_logitsprocs(custom_logitsprocs) return LogitsProcessors( - LogitProcessorWrapper(logit_processor, + LogitsProcessorWrapper(logit_processor, 
vllm_config, device, is_pin_memory, @@ -42,7 +51,13 @@ def build_logitsprocs_for_cb( ) -class LogitProcessorWrapper(LogitsProcessor): +class SpyreLogitsProcessor: + + def set_prefill_index(self, idx: int) -> None: + raise NotImplementedError + + +class LogitsProcessorWrapper(LogitsProcessor, SpyreLogitsProcessor): """Logit processor to inject expected token during generation for tests""" def __init__(self, logit_processor: LogitsProcessor, @@ -101,3 +116,103 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: def set_prefill_index(self, idx: int) -> None: self._prefill_index = idx + + +class SpyreMinPLogitsProcessor(LogitsProcessor, SpyreLogitsProcessor): + + def __init__(self, vllm_config: "VllmConfig", device: torch.device, + is_pin_memory: bool): + super().__init__(vllm_config, device, is_pin_memory) + self.prefill_index: Optional[int] = None + + def apply(self, logits: torch.Tensor) -> torch.Tensor: + + if not self.min_p_count: + return logits + + # if self.prefill_index is not None: + # pass + + # # Convert logits to probability distribution + # probability_values = torch.nn.functional.softmax(logits, dim=-1) + # # Calculate maximum probabilities per sequence + # max_probabilities = torch.amax(probability_values, + # dim=-1, + # keepdim=True) + # # Adjust min_p + # adjusted_min_p = max_probabilities.mul_(self.min_p) + # # Identify valid tokens using threshold comparison + # invalid_token_mask = probability_values < adjusted_min_p + # # Apply mask using boolean indexing + # logits[invalid_token_mask] = -float('inf') + # return logits + + +class SpyreLogitBiasLogitsProcessor(LogitBiasLogitsProcessor, + SpyreLogitsProcessor): + + def __init__(self, _, device: torch.device, is_pin_memory: bool): + self.device = device + self.pin_memory = is_pin_memory + self.biases: dict[int, dict[int, float]] = {} + + self.bias_tensor: torch.Tensor = torch.tensor(()) + self.logits_slice = (self._device_tensor([], torch.int32), + self._device_tensor([], torch.int32)) + + def is_argmax_invariant(self) -> bool: + """Logit bias can rebalance token probabilities and change the + outcome of argmax in greedy sampling.""" + return False + + def update_state(self, batch_update: Optional[BatchUpdate]): + needs_update = process_dict_updates( + self.biases, batch_update, + lambda params, _, __: params.logit_bias or None) + + # Update tensors if needed. 
+ if needs_update: + reqs: list[int] = [] + tok_ids: list[int] = [] + biases: list[float] = [] + for req, lb in self.biases.items(): + reqs.extend([req] * len(lb)) + tok_ids.extend(lb.keys()) + biases.extend(lb.values()) + + self.bias_tensor = self._device_tensor(biases, torch.float32) + self.logits_slice = (self._device_tensor(reqs, torch.int32), + self._device_tensor(tok_ids, torch.int32)) + + def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor: + return (torch.tensor(data, + device="cpu", + dtype=dtype, + pin_memory=self.pin_memory).to(device=self.device, + non_blocking=True)) + + def apply(self, logits: torch.Tensor) -> torch.Tensor: + if self.biases: + logits[self.logits_slice] += self.bias_tensor + return logits + + +class SpyreMinTokensLogitsProcessor(MinTokensLogitsProcessor, + SpyreLogitsProcessor): + + def __init__(self, vllm_config: VllmConfig, device: torch.device, + is_pin_memory: bool): + super().__init__(vllm_config, device, is_pin_memory) + self._prefill_index : Optional[int] = None + + def set_prefill_index(self, idx: int) -> None: + + + raise NotImplementedError + + + def apply(self, logits: torch.Tensor) -> torch.Tensor: + if self._prefill_index is not None: + pass + + return super().apply(self, logits) diff --git a/vllm_spyre/v1/worker/spyre_input_batch.py b/vllm_spyre/v1/worker/spyre_input_batch.py index fbdd0d15..89a1f3c6 100644 --- a/vllm_spyre/v1/worker/spyre_input_batch.py +++ b/vllm_spyre/v1/worker/spyre_input_batch.py @@ -18,7 +18,7 @@ MoveDirectionality) from vllm.v1.sample.metadata import SamplingMetadata -from vllm_spyre.v1.sample.spyre_logits_processor import LogitProcessorWrapper +from vllm_spyre.v1.sample.spyre_logits_processor import SpyreLogitsProcessor @dataclass @@ -301,8 +301,8 @@ def __init__(self, self.batch_update_builder = BatchUpdateBuilder() self.logitsprocs = logitsprocs or LogitsProcessors() - self.logitsprocs_wrappers = [lp for lp \ - in self.logitsprocs.all if isinstance(lp, LogitProcessorWrapper) + self.spyre_logitsprocs = [lp for lp \ + in self.logitsprocs.all if isinstance(lp, SpyreLogitsProcessor) ] self.has_allowed_token_ids: set[str] = set() diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py index 150ce912..cf8ba5a4 100644 --- a/vllm_spyre/v1/worker/spyre_model_runner.py +++ b/vllm_spyre/v1/worker/spyre_model_runner.py @@ -1015,7 +1015,7 @@ def _prepare_prompt( self.prefill_batch.add_request(req_state) # set prefill index for logits processor - for logitsproc in self.input_batch.logitsprocs_wrappers: + for logitsproc in self.input_batch.spyre_logitsprocs: logitsproc.set_prefill_index(prefill_index) # Refresh sampling metadata after all request are added to the batch From 27bf0f532cef734fdadd84723e72bd4819060eb7 Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Tue, 14 Oct 2025 10:52:20 -0300 Subject: [PATCH 02/16] test: test_mintokens_logits_processor Signed-off-by: Wallas Santos --- tests/e2e/test_sampling_params.py | 6 +- tests/utils/test_spyre_logits_processor.py | 145 ++++++++++++++++++ .../v1/sample/spyre_logits_processor.py | 55 +++++-- vllm_spyre/v1/worker/spyre_model_runner.py | 9 ++ 4 files changed, 197 insertions(+), 18 deletions(-) create mode 100644 tests/utils/test_spyre_logits_processor.py diff --git a/tests/e2e/test_sampling_params.py b/tests/e2e/test_sampling_params.py index ac34dcf4..53db3ceb 100644 --- a/tests/e2e/test_sampling_params.py +++ b/tests/e2e/test_sampling_params.py @@ -249,7 +249,7 @@ def test_spyre_batch1_logit_bias(model: ModelInfo, 
backend, monkeypatch, def test_spyre_batch1_min_tokens(model: ModelInfo, backend, monkeypatch, - use_llm_cache, max_model_len, max_num_seqs, + use_llm_cache, max_model_len, max_num_seqs, warmup_shapes, cb: int): spyre_model = get_cached_llm( model=model, @@ -259,7 +259,7 @@ def test_spyre_batch1_min_tokens(model: ModelInfo, backend, monkeypatch, monkeypatch=monkeypatch, warmup_shapes=warmup_shapes if cb != 1 else None, max_num_seqs=max_num_seqs if cb == 1 else None, - use_cb=cb==1) + use_cb=cb == 1) prompt = "What is the capital of the USA?" tokenizer = spyre_model.get_tokenizer() eos_id = tokenizer.eos_token_id @@ -267,7 +267,7 @@ def test_spyre_batch1_min_tokens(model: ModelInfo, backend, monkeypatch, params1 = SamplingParams(min_tokens=19, logit_bias={eos_id: 50}, seed=8780, - max_tokens=20) + max_tokens=25) params2 = SamplingParams(seed=8780, logit_bias={eos_id: 50}, max_tokens=20) output = spyre_model.generate([prompt] * 2, [params1, params2]) diff --git a/tests/utils/test_spyre_logits_processor.py b/tests/utils/test_spyre_logits_processor.py new file mode 100644 index 00000000..0d35fb6c --- /dev/null +++ b/tests/utils/test_spyre_logits_processor.py @@ -0,0 +1,145 @@ +from vllm_spyre.v1.sample.spyre_logits_processor import SpyreMinPLogitsProcessor, SpyreMinTokensLogitsProcessor, SpyreLogitsProcessor + +import math +from vllm.config import VllmConfig +from vllm.v1.sample.logits_processor import BatchUpdate +from vllm.sampling_params import SamplingParams +from vllm.v1.sample.logits_processor import BatchUpdateBuilder +import pytest +import torch + +EOS_TOKEN = 3 +VOCAB_SIZE = 8 + +class DummyVllmConfig: + + def __init__(self): + self.scheduler_config = DummySchedulerConfig() + +class DummySchedulerConfig: + + def __init__(self, max_num_seqs=1): + self.max_num_seqs = max_num_seqs + + +def generate_logits(batch_size: int =1): + + return torch.tensor(list(range(VOCAB_SIZE)) * batch_size, + dtype=torch.float32).reshape((batch_size, VOCAB_SIZE) ) + + +def prefill(params: SamplingParams, + batch_update_builder : BatchUpdateBuilder, + lp : SpyreLogitsProcessor, + logits : torch.Tensor, + ouput_tokens : list[int], + req_idx : int, + num_reqs: int): + params._all_stop_token_ids = set([EOS_TOKEN]) # + prompt_tokens = [0] * 8 + batch_update_builder.added.append((req_idx, params, prompt_tokens, ouput_tokens)) + batch_update = batch_update_builder.get_and_reset(num_reqs) + lp.update_state(batch_update) + + lp.set_prefill_index(req_idx) + out_logits = lp.apply(logits.clone()) + ouput_tokens.append(0) # just append a random token + + return out_logits + +def decode(batch_update_builder : BatchUpdateBuilder, + lp : SpyreLogitsProcessor, + logits : torch.Tensor, + batch_ouput_tokens : list[list[int]]): + + assert logits.shape[0] == len(batch_ouput_tokens) + + # This is called at each execute model in spyre model runner update_states + lp.update_state(None) + + out_logits = lp.apply(logits.clone()) + + for output_tokens in batch_ouput_tokens: + output_tokens.append(0) # just append a random token + + return out_logits + +@pytest.mark.cpu +@pytest.mark.worker +def test_mintokens_logits_processor(): + ''' + Tests the builtin SpyreMinTokensLogitsProcessor, + ''' + + device = torch.device('cpu') + + dummy_config = DummyVllmConfig() + lp = SpyreMinTokensLogitsProcessor(dummy_config, device, False) + + batch_update_builder = BatchUpdateBuilder() + + batch_output_tokens = [[], [], []] + + + # Step #0 Prefill req_id #0 (no min tokens) + logits = generate_logits(1) + out_logits = prefill(SamplingParams(), + 
batch_update_builder, + lp, + logits, + batch_output_tokens[0], + req_idx=0, + num_reqs=1) + + assert torch.allclose(logits, out_logits) # Do nothing + + # Step #1 Prefill req_id #1 (with min tokens) + params = SamplingParams(min_tokens=4) + + logits = generate_logits(1) + out_logits = prefill(params, + batch_update_builder, + lp, + logits, + batch_output_tokens[1], + req_idx=1, + num_reqs=2) + + assert out_logits[0][EOS_TOKEN].item() == -math.inf + + # Step #2 Prefill req_id #1 + logits = generate_logits(1) + out_logits = prefill(SamplingParams(), + batch_update_builder, + lp, + logits, + batch_output_tokens[2], + req_idx=2, + num_reqs=3) + + assert torch.allclose(logits, out_logits) # Do nothing + + # Step #3 - #6 Decoding, eos_token for req #1 must be -inf + for _ in range(3): + logits = generate_logits(3) + out_logits = decode(batch_update_builder, + lp, + logits, + batch_output_tokens, + ) + + assert torch.allclose(logits[0], out_logits[0]) + assert torch.allclose(logits[2], out_logits[2]) + assert out_logits[1][EOS_TOKEN].item() == -math.inf + + + # Step #7, min tokens reached, no changes in logits anymore + logits = generate_logits(3) + out_logits = decode(batch_update_builder, + lp, + logits, + batch_output_tokens, + ) + + assert torch.allclose(logits, out_logits) + diff --git a/vllm_spyre/v1/sample/spyre_logits_processor.py b/vllm_spyre/v1/sample/spyre_logits_processor.py index f445355d..206bfe91 100644 --- a/vllm_spyre/v1/sample/spyre_logits_processor.py +++ b/vllm_spyre/v1/sample/spyre_logits_processor.py @@ -1,13 +1,17 @@ import itertools -from typing import Optional, Sequence, Union, TYPE_CHECKING +from typing import TYPE_CHECKING, Optional, Sequence, Union import torch - from vllm.logger import init_logger -from vllm.v1.sample.logits_processor import ( - BUILTIN_LOGITS_PROCESSORS, STR_POOLING_REJECTS_LOGITSPROCS, BatchUpdate, - LogitsProcessor, MinPLogitsProcessor, LogitBiasLogitsProcessor, - MinTokensLogitsProcessor, _load_custom_logitsprocs) +from vllm.v1.sample.logits_processor import (BUILTIN_LOGITS_PROCESSORS, + STR_POOLING_REJECTS_LOGITSPROCS, + BatchUpdate, + LogitBiasLogitsProcessor, + LogitsProcessor, + MinPLogitsProcessor, + MinTokensLogitsProcessor, + _load_custom_logitsprocs, + process_dict_updates) from vllm.v1.sample.logits_processor.state import LogitsProcessors logger = init_logger(__name__) @@ -118,7 +122,7 @@ def set_prefill_index(self, idx: int) -> None: self._prefill_index = idx -class SpyreMinPLogitsProcessor(LogitsProcessor, SpyreLogitsProcessor): +class SpyreMinPLogitsProcessor(MinPLogitsProcessor, SpyreLogitsProcessor): def __init__(self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool): @@ -203,16 +207,37 @@ class SpyreMinTokensLogitsProcessor(MinTokensLogitsProcessor, def __init__(self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool): super().__init__(vllm_config, device, is_pin_memory) - self._prefill_index : Optional[int] = None + self._prefill_slice : Optional[tuple[torch.Tensor, torch.Tensor]] \ + = None + self._is_prefill: bool = False def set_prefill_index(self, idx: int) -> None: - - - raise NotImplementedError - + + reqs: list[int] = [] + tok_ids: list[int] = [] + for req, (_, _, stop_tok_ids) in self.min_toks.items(): + if req == idx: + # NOTE: always request 0 for prefill + # logits will only have logits for a single request + reqs.extend([0] * len(stop_tok_ids)) + tok_ids.extend(stop_tok_ids) + + if reqs and tok_ids: + self._prefill_slice = (self._device_tensor(reqs, torch.int32), + 
self._device_tensor(tok_ids, torch.int32)) + self._is_prefill = True def apply(self, logits: torch.Tensor) -> torch.Tensor: - if self._prefill_index is not None: - pass + if self._prefill_slice is not None: + logits[self._prefill_slice] = -float("inf") + self._prefill_slice = None + self._is_prefill = False + return logits + elif self._is_prefill: + # It is prefill but we do not need to do anything + # for the prefill request, just return logits to + # avoid slice the logits with batch_size = 1 + self._is_prefill = False + return logits - return super().apply(self, logits) + return super().apply(logits) diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py index cf8ba5a4..3866279e 100644 --- a/vllm_spyre/v1/worker/spyre_model_runner.py +++ b/vllm_spyre/v1/worker/spyre_model_runner.py @@ -404,6 +404,9 @@ def update_states(self, scheduler_output: SchedulerOutput): # of logitprocs. Refactor so that we can batch removals to the # `input_batch` self.input_batch.refresh_metadata() + else: + # Due to logits processor we need to refresh metadata at each step + self.input_batch.refresh_metadata() def _get_prompt_logprobs_dict( self, @@ -1367,6 +1370,12 @@ def build_input_batch(self) -> SamplingInputBatch: is_pooling_model=False, custom_logitsprocs=custom_logitsprocs, batch_size=batch_size) + # logits_processors = \ + # build_logitsprocs(vllm_config=self.vllm_config, + # device=self.device, + # is_pin_memory=self.pin_memory, + # is_pooling_model=False, + # custom_logitsprocs=custom_logitsprocs) return SamplingInputBatch( max_num_reqs=batch_size, From 89a9fbb599637edc64189a50e9ddebb9302a2041 Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Tue, 14 Oct 2025 11:03:17 -0300 Subject: [PATCH 03/16] test: add docs and code cleanup Signed-off-by: Wallas Santos --- tests/utils/test_spyre_logits_processor.py | 122 +++++++++++---------- 1 file changed, 62 insertions(+), 60 deletions(-) diff --git a/tests/utils/test_spyre_logits_processor.py b/tests/utils/test_spyre_logits_processor.py index 0d35fb6c..7a40a17a 100644 --- a/tests/utils/test_spyre_logits_processor.py +++ b/tests/utils/test_spyre_logits_processor.py @@ -1,78 +1,81 @@ -from vllm_spyre.v1.sample.spyre_logits_processor import SpyreMinPLogitsProcessor, SpyreMinTokensLogitsProcessor, SpyreLogitsProcessor - import math -from vllm.config import VllmConfig -from vllm.v1.sample.logits_processor import BatchUpdate -from vllm.sampling_params import SamplingParams -from vllm.v1.sample.logits_processor import BatchUpdateBuilder + import pytest import torch +from vllm.sampling_params import SamplingParams +from vllm.v1.sample.logits_processor import BatchUpdateBuilder + +from vllm_spyre.v1.sample.spyre_logits_processor import ( + SpyreLogitsProcessor, SpyreMinTokensLogitsProcessor) EOS_TOKEN = 3 VOCAB_SIZE = 8 + class DummyVllmConfig: def __init__(self): self.scheduler_config = DummySchedulerConfig() + class DummySchedulerConfig: - + def __init__(self, max_num_seqs=1): self.max_num_seqs = max_num_seqs -def generate_logits(batch_size: int =1): +def generate_logits(batch_size: int = 1): - return torch.tensor(list(range(VOCAB_SIZE)) * batch_size, - dtype=torch.float32).reshape((batch_size, VOCAB_SIZE) ) + return torch.tensor(list(range(VOCAB_SIZE)) * batch_size, + dtype=torch.float32).reshape((batch_size, VOCAB_SIZE)) -def prefill(params: SamplingParams, - batch_update_builder : BatchUpdateBuilder, - lp : SpyreLogitsProcessor, - logits : torch.Tensor, - ouput_tokens : list[int], - req_idx : int, - num_reqs: int): - 
params._all_stop_token_ids = set([EOS_TOKEN]) # +def prefill(params: SamplingParams, batch_update_builder: BatchUpdateBuilder, + lp: SpyreLogitsProcessor, logits: torch.Tensor, + ouput_tokens: list[int], req_idx: int, num_reqs: int): + params._all_stop_token_ids = set([EOS_TOKEN]) # prompt_tokens = [0] * 8 - batch_update_builder.added.append((req_idx, params, prompt_tokens, ouput_tokens)) + batch_update_builder.added.append( + (req_idx, params, prompt_tokens, ouput_tokens)) batch_update = batch_update_builder.get_and_reset(num_reqs) lp.update_state(batch_update) lp.set_prefill_index(req_idx) out_logits = lp.apply(logits.clone()) - ouput_tokens.append(0) # just append a random token + ouput_tokens.append(0) # just append a random token return out_logits -def decode(batch_update_builder : BatchUpdateBuilder, - lp : SpyreLogitsProcessor, - logits : torch.Tensor, - batch_ouput_tokens : list[list[int]]): - + +def decode(batch_update_builder: BatchUpdateBuilder, lp: SpyreLogitsProcessor, + logits: torch.Tensor, batch_ouput_tokens: list[list[int]]): + assert logits.shape[0] == len(batch_ouput_tokens) # This is called at each execute model in spyre model runner update_states - lp.update_state(None) + lp.update_state(None) out_logits = lp.apply(logits.clone()) for output_tokens in batch_ouput_tokens: - output_tokens.append(0) # just append a random token + output_tokens.append(0) # just append a random token return out_logits - + + @pytest.mark.cpu @pytest.mark.worker def test_mintokens_logits_processor(): - ''' - Tests the builtin SpyreMinTokensLogitsProcessor, - ''' - + """ + This method tests the SpyreMinTokensLogitsProcessor class. + + The test case simulates partially the step of the engine with focus on + logits processor only. + The logits processor should mark the EOS token as -inf once it reaches + the minimum number of tokens specified in the SamplingParams. 
+ """ device = torch.device('cpu') - + dummy_config = DummyVllmConfig() lp = SpyreMinTokensLogitsProcessor(dummy_config, device, False) @@ -80,27 +83,26 @@ def test_mintokens_logits_processor(): batch_output_tokens = [[], [], []] - # Step #0 Prefill req_id #0 (no min tokens) logits = generate_logits(1) - out_logits = prefill(SamplingParams(), + out_logits = prefill(SamplingParams(), batch_update_builder, - lp, - logits, + lp, + logits, batch_output_tokens[0], req_idx=0, num_reqs=1) - assert torch.allclose(logits, out_logits) # Do nothing + assert torch.allclose(logits, out_logits) # Do nothing # Step #1 Prefill req_id #1 (with min tokens) params = SamplingParams(min_tokens=4) - + logits = generate_logits(1) - out_logits = prefill(params, + out_logits = prefill(params, batch_update_builder, - lp, - logits, + lp, + logits, batch_output_tokens[1], req_idx=1, num_reqs=2) @@ -109,37 +111,37 @@ def test_mintokens_logits_processor(): # Step #2 Prefill req_id #1 logits = generate_logits(1) - out_logits = prefill(SamplingParams(), + out_logits = prefill(SamplingParams(), batch_update_builder, - lp, - logits, + lp, + logits, batch_output_tokens[2], req_idx=2, num_reqs=3) - assert torch.allclose(logits, out_logits) # Do nothing + assert torch.allclose(logits, out_logits) # Do nothing # Step #3 - #6 Decoding, eos_token for req #1 must be -inf for _ in range(3): logits = generate_logits(3) - out_logits = decode(batch_update_builder, - lp, - logits, - batch_output_tokens, - ) - + out_logits = decode( + batch_update_builder, + lp, + logits, + batch_output_tokens, + ) + assert torch.allclose(logits[0], out_logits[0]) assert torch.allclose(logits[2], out_logits[2]) assert out_logits[1][EOS_TOKEN].item() == -math.inf - # Step #7, min tokens reached, no changes in logits anymore logits = generate_logits(3) - out_logits = decode(batch_update_builder, - lp, - logits, - batch_output_tokens, - ) - + out_logits = decode( + batch_update_builder, + lp, + logits, + batch_output_tokens, + ) + assert torch.allclose(logits, out_logits) - From b74ba58530bd3818592696f118a994daaddc8d70 Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Tue, 14 Oct 2025 12:20:47 -0300 Subject: [PATCH 04/16] feat: added tests for logits bias Signed-off-by: Wallas Santos --- tests/e2e/test_sampling_params.py | 24 ++-- tests/utils/test_spyre_logits_processor.py | 73 +++++++++++- .../v1/sample/spyre_logits_processor.py | 110 +++++++++++------- vllm_spyre/v1/worker/spyre_model_runner.py | 14 +-- 4 files changed, 154 insertions(+), 67 deletions(-) diff --git a/tests/e2e/test_sampling_params.py b/tests/e2e/test_sampling_params.py index 53db3ceb..fc9b6d02 100644 --- a/tests/e2e/test_sampling_params.py +++ b/tests/e2e/test_sampling_params.py @@ -209,16 +209,21 @@ def test_spyre_batch1_top_k(model: ModelInfo, backend, monkeypatch, assert token_div1 < token_div2 +# def test_spyre_batch1_min_tokens(model: ModelInfo, backend, monkeypatch, +# use_llm_cache, max_model_len, max_num_seqs, +# warmup_shapes, cb: int): def test_spyre_batch1_logit_bias(model: ModelInfo, backend, monkeypatch, - use_llm_cache, warmup_shapes): + use_llm_cache, warmup_shapes, max_model_len, + max_num_seqs, cb: int): spyre_model = get_cached_llm( model=model, - max_model_len=128, + max_model_len=max_model_len, + max_num_seqs=max_num_seqs, tensor_parallel_size=1, backend=backend, monkeypatch=monkeypatch, - warmup_shapes=warmup_shapes, - ) + warmup_shapes=warmup_shapes if cb == 0 else None, + use_cb=cb == 1) tokenizer = spyre_model.get_tokenizer() banned_word = "train" forced_word = 
"plane" @@ -239,13 +244,12 @@ def test_spyre_batch1_logit_bias(model: ModelInfo, backend, monkeypatch, }) params2 = SamplingParams(temperature=0, seed=8780, max_tokens=5) - output1 = spyre_model.generate(prompt, params1)[0] - output2 = spyre_model.generate(prompt, params2)[0] + output = spyre_model.generate([prompt, prompt], [params1, params2]) - assert banned_word not in output1.outputs[0].text.lower() - assert forced_word in output1.outputs[0].text.lower() + assert banned_word not in output[0].outputs[0].text.lower() + assert forced_word in output[0].outputs[0].text.lower() - assert output1.outputs[0].text != output2.outputs[0].text + assert output[0].outputs[0].text != output[1].outputs[0].text def test_spyre_batch1_min_tokens(model: ModelInfo, backend, monkeypatch, @@ -267,7 +271,7 @@ def test_spyre_batch1_min_tokens(model: ModelInfo, backend, monkeypatch, params1 = SamplingParams(min_tokens=19, logit_bias={eos_id: 50}, seed=8780, - max_tokens=25) + max_tokens=20) params2 = SamplingParams(seed=8780, logit_bias={eos_id: 50}, max_tokens=20) output = spyre_model.generate([prompt] * 2, [params1, params2]) diff --git a/tests/utils/test_spyre_logits_processor.py b/tests/utils/test_spyre_logits_processor.py index 7a40a17a..d0eeeb6e 100644 --- a/tests/utils/test_spyre_logits_processor.py +++ b/tests/utils/test_spyre_logits_processor.py @@ -6,7 +6,7 @@ from vllm.v1.sample.logits_processor import BatchUpdateBuilder from vllm_spyre.v1.sample.spyre_logits_processor import ( - SpyreLogitsProcessor, SpyreMinTokensLogitsProcessor) + SpyreLogitsProcessor, SpyreMinTokensLogitsProcessor, SpyreLogitBiasLogitsProcessor) EOS_TOKEN = 3 VOCAB_SIZE = 8 @@ -93,7 +93,7 @@ def test_mintokens_logits_processor(): req_idx=0, num_reqs=1) - assert torch.allclose(logits, out_logits) # Do nothing + assert torch.allclose(logits, out_logits) # Step #1 Prefill req_id #1 (with min tokens) params = SamplingParams(min_tokens=4) @@ -109,7 +109,7 @@ def test_mintokens_logits_processor(): assert out_logits[0][EOS_TOKEN].item() == -math.inf - # Step #2 Prefill req_id #1 + # Step #2 Prefill req_id #2 logits = generate_logits(1) out_logits = prefill(SamplingParams(), batch_update_builder, @@ -119,7 +119,7 @@ def test_mintokens_logits_processor(): req_idx=2, num_reqs=3) - assert torch.allclose(logits, out_logits) # Do nothing + assert torch.allclose(logits, out_logits) # Step #3 - #6 Decoding, eos_token for req #1 must be -inf for _ in range(3): @@ -145,3 +145,68 @@ def test_mintokens_logits_processor(): ) assert torch.allclose(logits, out_logits) + +@pytest.mark.cpu +@pytest.mark.worker +def test_logitbias_logits_processor(): + device = torch.device('cpu') + + dummy_config = DummyVllmConfig() + lp = SpyreLogitBiasLogitsProcessor(dummy_config, device, False) + + batch_update_builder = BatchUpdateBuilder() + + batch_output_tokens = [[], [], []] + + # Step #0 Prefill req_id #0 (no logits bias) + logits = generate_logits(1) + out_logits = prefill(SamplingParams(), + batch_update_builder, + lp, + logits, + batch_output_tokens[0], + req_idx=0, + num_reqs=1) + + assert torch.allclose(logits, out_logits) + + # Step #1 Prefill req_id #1 (with logits bias) + params = SamplingParams(logit_bias={1: 100, 7: -100}) + + logits = generate_logits(1) + out_logits = prefill(params, + batch_update_builder, + lp, + logits, + batch_output_tokens[1], + req_idx=1, + num_reqs=2) + + assert out_logits[0][1].item() >= 100 + assert out_logits[0][7].item() < 0 + + # Step #2 Prefill req_id #2 (no logits bias) + logits = generate_logits(1) + out_logits = 
prefill(SamplingParams(), + batch_update_builder, + lp, + logits, + batch_output_tokens[2], + req_idx=2, + num_reqs=3) + + assert torch.allclose(logits, out_logits) + + # Step #4, decoding, keep applying lotis bias + logits = generate_logits(3) + out_logits = decode( + batch_update_builder, + lp, + logits, + batch_output_tokens, + ) + + assert torch.allclose(logits[0], out_logits[0]) + assert torch.allclose(logits[2], out_logits[2]) + assert out_logits[1][1].item() >= 100 + assert out_logits[1][7].item() < 0 diff --git a/vllm_spyre/v1/sample/spyre_logits_processor.py b/vllm_spyre/v1/sample/spyre_logits_processor.py index 206bfe91..65c100d6 100644 --- a/vllm_spyre/v1/sample/spyre_logits_processor.py +++ b/vllm_spyre/v1/sample/spyre_logits_processor.py @@ -42,17 +42,42 @@ def build_logitsprocs_for_cb( return LogitsProcessors() custom_logitsprocs_classes = _load_custom_logitsprocs(custom_logitsprocs) - return LogitsProcessors( - LogitsProcessorWrapper(logit_processor, + # LogitBiasLogitsProcessor, + # LogitsProcessor, + # MinPLogitsProcessor, + # MinTokensLogitsProcessor, + return LogitsProcessors( itertools.chain( + [SpyreLogitBiasLogitsProcessor(vllm_config, + device, + is_pin_memory), + LogitsProcessorWrapper(MinPLogitsProcessor, + vllm_config, + device, + is_pin_memory, + batch_size), + SpyreMinTokensLogitsProcessor(vllm_config, + device, + is_pin_memory), + ], + [LogitsProcessorWrapper(logit_processor, vllm_config, device, is_pin_memory, batch_size) \ - for logit_processor in itertools.chain( - BUILTIN_LOGITS_PROCESSORS, - custom_logitsprocs_classes - ) - ) + for logit_processor in custom_logitsprocs_classes] + )) + # return LogitsProcessors( + + # LogitsProcessorWrapper(logit_processor, + # vllm_config, + # device, + # is_pin_memory, + # batch_size) \ + # for logit_processor in itertools.chain( + # BUILTIN_LOGITS_PROCESSORS, + # custom_logitsprocs_classes + # ) + # ) class SpyreLogitsProcessor: @@ -155,50 +180,49 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: class SpyreLogitBiasLogitsProcessor(LogitBiasLogitsProcessor, SpyreLogitsProcessor): - def __init__(self, _, device: torch.device, is_pin_memory: bool): - self.device = device - self.pin_memory = is_pin_memory - self.biases: dict[int, dict[int, float]] = {} + def __init__(self, config : VllmConfig, device: torch.device, is_pin_memory: bool): + super().__init__(config, device, is_pin_memory) - self.bias_tensor: torch.Tensor = torch.tensor(()) - self.logits_slice = (self._device_tensor([], torch.int32), - self._device_tensor([], torch.int32)) + self._is_prefill : bool = False + self._prefill_slice : Optional[tuple[torch.Tensor, torch.Tensor]] \ + = None + self._prefill_bias : torch.Tensor = torch.tensor(()) - def is_argmax_invariant(self) -> bool: - """Logit bias can rebalance token probabilities and change the - outcome of argmax in greedy sampling.""" - return False - def update_state(self, batch_update: Optional[BatchUpdate]): - needs_update = process_dict_updates( - self.biases, batch_update, - lambda params, _, __: params.logit_bias or None) - - # Update tensors if needed. 
- if needs_update: - reqs: list[int] = [] - tok_ids: list[int] = [] - biases: list[float] = [] - for req, lb in self.biases.items(): - reqs.extend([req] * len(lb)) + def set_prefill_index(self, idx: int) -> None: + + reqs: list[int] = [] + tok_ids: list[int] = [] + biases: list[float] = [] + for req, lb in self.biases.items(): + if req == idx: + # NOTE: always request 0 for prefill + # prefill will only have logits for a single request + reqs.extend([0] * len(lb)) tok_ids.extend(lb.keys()) biases.extend(lb.values()) - self.bias_tensor = self._device_tensor(biases, torch.float32) - self.logits_slice = (self._device_tensor(reqs, torch.int32), - self._device_tensor(tok_ids, torch.int32)) + if biases: + self._prefill_slice = (self._device_tensor(reqs, torch.int32), + self._device_tensor(tok_ids, torch.int32)) + self._prefill_bias = self._device_tensor(biases, torch.float32) + self._is_prefill = True - def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor: - return (torch.tensor(data, - device="cpu", - dtype=dtype, - pin_memory=self.pin_memory).to(device=self.device, - non_blocking=True)) def apply(self, logits: torch.Tensor) -> torch.Tensor: - if self.biases: - logits[self.logits_slice] += self.bias_tensor - return logits + if self._prefill_slice is not None: + logits[self._prefill_slice] += self._prefill_bias + self._prefill_slice = None + self._is_prefill = False + return logits + elif self._is_prefill: + # It is prefill but we do not need to do anything + # for the prefill request, just return logits to + # avoid slice the logits with batch_size = 1 + self._is_prefill = False + return logits + + return super().apply(logits) class SpyreMinTokensLogitsProcessor(MinTokensLogitsProcessor, @@ -218,7 +242,7 @@ def set_prefill_index(self, idx: int) -> None: for req, (_, _, stop_tok_ids) in self.min_toks.items(): if req == idx: # NOTE: always request 0 for prefill - # logits will only have logits for a single request + # prefill will only have logits for a single request reqs.extend([0] * len(stop_tok_ids)) tok_ids.extend(stop_tok_ids) diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py index 3866279e..5684a0da 100644 --- a/vllm_spyre/v1/worker/spyre_model_runner.py +++ b/vllm_spyre/v1/worker/spyre_model_runner.py @@ -1017,14 +1017,14 @@ def _prepare_prompt( prefill_index = self.input_batch.add_request(req_state) self.prefill_batch.add_request(req_state) - # set prefill index for logits processor - for logitsproc in self.input_batch.spyre_logitsprocs: - logitsproc.set_prefill_index(prefill_index) - # Refresh sampling metadata after all request are added to the batch self.input_batch.refresh_metadata() self.prefill_batch.refresh_metadata() + # set prefill index for logits processor + for logitsproc in self.input_batch.spyre_logitsprocs: + logitsproc.set_prefill_index(prefill_index) + self.model.indices = torch.ones(1, dtype=torch.bool, device='cpu') slot_mapping = torch.tensor([slots], dtype=torch.int64) prompt_token_ids_tensor = torch.tensor(prompt_token_ids, @@ -1370,12 +1370,6 @@ def build_input_batch(self) -> SamplingInputBatch: is_pooling_model=False, custom_logitsprocs=custom_logitsprocs, batch_size=batch_size) - # logits_processors = \ - # build_logitsprocs(vllm_config=self.vllm_config, - # device=self.device, - # is_pin_memory=self.pin_memory, - # is_pooling_model=False, - # custom_logitsprocs=custom_logitsprocs) return SamplingInputBatch( max_num_reqs=batch_size, From 07e2f907342933e7c56b361f2dfd382bfd04ede9 Mon Sep 17 
00:00:00 2001 From: Wallas Santos Date: Tue, 14 Oct 2025 14:49:41 -0300 Subject: [PATCH 05/16] feat: tests for min p Signed-off-by: Wallas Santos --- tests/e2e/test_sampling_params.py | 13 +- tests/utils/test_spyre_logits_processor.py | 118 ++++++++++++++++-- .../v1/sample/spyre_logits_processor.py | 71 +++++------ 3 files changed, 143 insertions(+), 59 deletions(-) diff --git a/tests/e2e/test_sampling_params.py b/tests/e2e/test_sampling_params.py index fc9b6d02..5d141411 100644 --- a/tests/e2e/test_sampling_params.py +++ b/tests/e2e/test_sampling_params.py @@ -209,9 +209,6 @@ def test_spyre_batch1_top_k(model: ModelInfo, backend, monkeypatch, assert token_div1 < token_div2 -# def test_spyre_batch1_min_tokens(model: ModelInfo, backend, monkeypatch, -# use_llm_cache, max_model_len, max_num_seqs, -# warmup_shapes, cb: int): def test_spyre_batch1_logit_bias(model: ModelInfo, backend, monkeypatch, use_llm_cache, warmup_shapes, max_model_len, max_num_seqs, cb: int): @@ -315,15 +312,17 @@ def test_spyre_batch1_ignore_eos(model: ModelInfo, backend, monkeypatch, def test_spyre_batch1_min_p(model: ModelInfo, backend, monkeypatch, - use_llm_cache, warmup_shapes): + use_llm_cache, max_model_len, max_num_seqs, + warmup_shapes, cb: int): spyre_model = get_cached_llm( model=model, - max_model_len=128, + max_model_len=max_model_len, + max_num_seqs=max_num_seqs, tensor_parallel_size=1, backend=backend, monkeypatch=monkeypatch, - warmup_shapes=warmup_shapes, - ) + warmup_shapes=warmup_shapes if cb == 0 else None, + use_cb=cb == 1) prompt = "The opposite of black is" params1 = SamplingParams(min_p=0.5, temperature=1, max_tokens=5) params2 = SamplingParams(temperature=1, max_tokens=5) diff --git a/tests/utils/test_spyre_logits_processor.py b/tests/utils/test_spyre_logits_processor.py index d0eeeb6e..04eb6724 100644 --- a/tests/utils/test_spyre_logits_processor.py +++ b/tests/utils/test_spyre_logits_processor.py @@ -3,10 +3,12 @@ import pytest import torch from vllm.sampling_params import SamplingParams -from vllm.v1.sample.logits_processor import BatchUpdateBuilder +from vllm.v1.sample.logits_processor import (BatchUpdateBuilder, + MinPLogitsProcessor) from vllm_spyre.v1.sample.spyre_logits_processor import ( - SpyreLogitsProcessor, SpyreMinTokensLogitsProcessor, SpyreLogitBiasLogitsProcessor) + SpyreLogitBiasLogitsProcessor, SpyreLogitsProcessor, + SpyreMinPLogitsProcessor, SpyreMinTokensLogitsProcessor) EOS_TOKEN = 3 VOCAB_SIZE = 8 @@ -20,7 +22,7 @@ def __init__(self): class DummySchedulerConfig: - def __init__(self, max_num_seqs=1): + def __init__(self, max_num_seqs=4): self.max_num_seqs = max_num_seqs @@ -40,7 +42,9 @@ def prefill(params: SamplingParams, batch_update_builder: BatchUpdateBuilder, batch_update = batch_update_builder.get_and_reset(num_reqs) lp.update_state(batch_update) - lp.set_prefill_index(req_idx) + if isinstance(lp, SpyreLogitsProcessor): + lp.set_prefill_index(req_idx) + out_logits = lp.apply(logits.clone()) ouput_tokens.append(0) # just append a random token @@ -93,7 +97,7 @@ def test_mintokens_logits_processor(): req_idx=0, num_reqs=1) - assert torch.allclose(logits, out_logits) + assert torch.allclose(logits, out_logits) # Step #1 Prefill req_id #1 (with min tokens) params = SamplingParams(min_tokens=4) @@ -119,7 +123,7 @@ def test_mintokens_logits_processor(): req_idx=2, num_reqs=3) - assert torch.allclose(logits, out_logits) + assert torch.allclose(logits, out_logits) # Step #3 - #6 Decoding, eos_token for req #1 must be -inf for _ in range(3): @@ -146,6 +150,7 @@ def 
test_mintokens_logits_processor(): assert torch.allclose(logits, out_logits) + @pytest.mark.cpu @pytest.mark.worker def test_logitbias_logits_processor(): @@ -168,7 +173,7 @@ def test_logitbias_logits_processor(): req_idx=0, num_reqs=1) - assert torch.allclose(logits, out_logits) + assert torch.allclose(logits, out_logits) # Step #1 Prefill req_id #1 (with logits bias) params = SamplingParams(logit_bias={1: 100, 7: -100}) @@ -195,9 +200,9 @@ def test_logitbias_logits_processor(): req_idx=2, num_reqs=3) - assert torch.allclose(logits, out_logits) + assert torch.allclose(logits, out_logits) - # Step #4, decoding, keep applying lotis bias + # Step #4, decoding, keep applying logits bias logits = generate_logits(3) out_logits = decode( batch_update_builder, @@ -210,3 +215,98 @@ def test_logitbias_logits_processor(): assert torch.allclose(logits[2], out_logits[2]) assert out_logits[1][1].item() >= 100 assert out_logits[1][7].item() < 0 + + +@pytest.mark.cpu +@pytest.mark.worker +def test_minp_logits_processor(): + device = torch.device('cpu') + + dummy_config = DummyVllmConfig() + lp = SpyreMinPLogitsProcessor(dummy_config, device, False) + # vllm logits processor + vllm_lp = MinPLogitsProcessor(dummy_config, device, False) + + batch_update_builder = BatchUpdateBuilder() + + batch_output_tokens = [[], [], []] + + # Step #0 Prefill req_id #0 (no min p) + logits = generate_logits(1) + out_logits = prefill(SamplingParams(), + batch_update_builder, + lp, + logits, + batch_output_tokens[0], + req_idx=0, + num_reqs=1) + vllm_logits = prefill(SamplingParams(), + batch_update_builder, + vllm_lp, + logits, + batch_output_tokens[0], + req_idx=0, + num_reqs=1) + + assert torch.allclose(vllm_logits, out_logits) + + # Step #1 Prefill req_id #1 (with logits bias) + params = SamplingParams(min_p=0.15) + + logits = generate_logits(1) + out_logits = prefill(params, + batch_update_builder, + lp, + logits, + batch_output_tokens[1], + req_idx=1, + num_reqs=2) + logits = generate_logits(2) + vllm_logits = prefill(params, + batch_update_builder, + vllm_lp, + logits, + batch_output_tokens[1], + req_idx=1, + num_reqs=2) + + assert torch.allclose(vllm_logits[1], out_logits[0]) + + # Step #2 Prefill req_id #2 (no min p) + logits = generate_logits(1) + out_logits = prefill(SamplingParams(), + batch_update_builder, + lp, + logits, + batch_output_tokens[2], + req_idx=2, + num_reqs=3) + + logits = generate_logits(3) + vllm_logits = prefill(SamplingParams(), + batch_update_builder, + vllm_lp, + logits, + batch_output_tokens[2], + req_idx=2, + num_reqs=3) + + assert torch.allclose(logits, out_logits) + + # Step #4, decoding, keep applying min p + logits = generate_logits(3) + out_logits = decode( + batch_update_builder, + lp, + logits, + batch_output_tokens, + ) + + vllm_logits = decode( + batch_update_builder, + vllm_lp, + logits, + batch_output_tokens, + ) + + assert torch.allclose(out_logits, vllm_logits) diff --git a/vllm_spyre/v1/sample/spyre_logits_processor.py b/vllm_spyre/v1/sample/spyre_logits_processor.py index 65c100d6..4bf03fe6 100644 --- a/vllm_spyre/v1/sample/spyre_logits_processor.py +++ b/vllm_spyre/v1/sample/spyre_logits_processor.py @@ -10,8 +10,7 @@ LogitsProcessor, MinPLogitsProcessor, MinTokensLogitsProcessor, - _load_custom_logitsprocs, - process_dict_updates) + _load_custom_logitsprocs) from vllm.v1.sample.logits_processor.state import LogitsProcessors logger = init_logger(__name__) @@ -42,19 +41,13 @@ def build_logitsprocs_for_cb( return LogitsProcessors() custom_logitsprocs_classes = 
_load_custom_logitsprocs(custom_logitsprocs) - # LogitBiasLogitsProcessor, - # LogitsProcessor, - # MinPLogitsProcessor, - # MinTokensLogitsProcessor, return LogitsProcessors( itertools.chain( [SpyreLogitBiasLogitsProcessor(vllm_config, device, is_pin_memory), - LogitsProcessorWrapper(MinPLogitsProcessor, - vllm_config, + SpyreMinPLogitsProcessor(vllm_config, device, - is_pin_memory, - batch_size), + is_pin_memory), SpyreMinTokensLogitsProcessor(vllm_config, device, is_pin_memory), @@ -66,18 +59,6 @@ def build_logitsprocs_for_cb( batch_size) \ for logit_processor in custom_logitsprocs_classes] )) - # return LogitsProcessors( - - # LogitsProcessorWrapper(logit_processor, - # vllm_config, - # device, - # is_pin_memory, - # batch_size) \ - # for logit_processor in itertools.chain( - # BUILTIN_LOGITS_PROCESSORS, - # custom_logitsprocs_classes - # ) - # ) class SpyreLogitsProcessor: @@ -152,42 +133,47 @@ class SpyreMinPLogitsProcessor(MinPLogitsProcessor, SpyreLogitsProcessor): def __init__(self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool): super().__init__(vllm_config, device, is_pin_memory) - self.prefill_index: Optional[int] = None + self._prefill_index: Optional[int] = None + + def set_prefill_index(self, idx: int) -> None: + self._prefill_index = idx def apply(self, logits: torch.Tensor) -> torch.Tensor: + if self._prefill_index is None: + return super().apply(logits) if not self.min_p_count: return logits - # if self.prefill_index is not None: - # pass + # Convert logits to probability distribution + probability_values = torch.nn.functional.softmax(logits, dim=-1) + # Calculate maximum probabilities per sequence + max_probabilities = torch.amax(probability_values, + dim=-1, + keepdim=True) + # Adjust min_p + adjusted_min_p = max_probabilities.mul_( + self.min_p[self._prefill_index].unsqueeze(0)) + # Identify valid tokens using threshold comparison + invalid_token_mask = probability_values < adjusted_min_p + # Apply mask using boolean indexing + logits[invalid_token_mask] = -float('inf') + self._prefill_index = None - # # Convert logits to probability distribution - # probability_values = torch.nn.functional.softmax(logits, dim=-1) - # # Calculate maximum probabilities per sequence - # max_probabilities = torch.amax(probability_values, - # dim=-1, - # keepdim=True) - # # Adjust min_p - # adjusted_min_p = max_probabilities.mul_(self.min_p) - # # Identify valid tokens using threshold comparison - # invalid_token_mask = probability_values < adjusted_min_p - # # Apply mask using boolean indexing - # logits[invalid_token_mask] = -float('inf') - # return logits + return logits class SpyreLogitBiasLogitsProcessor(LogitBiasLogitsProcessor, SpyreLogitsProcessor): - def __init__(self, config : VllmConfig, device: torch.device, is_pin_memory: bool): + def __init__(self, config: VllmConfig, device: torch.device, + is_pin_memory: bool): super().__init__(config, device, is_pin_memory) - self._is_prefill : bool = False + self._is_prefill: bool = False self._prefill_slice : Optional[tuple[torch.Tensor, torch.Tensor]] \ = None - self._prefill_bias : torch.Tensor = torch.tensor(()) - + self._prefill_bias: torch.Tensor = torch.tensor(()) def set_prefill_index(self, idx: int) -> None: @@ -208,7 +194,6 @@ def set_prefill_index(self, idx: int) -> None: self._prefill_bias = self._device_tensor(biases, torch.float32) self._is_prefill = True - def apply(self, logits: torch.Tensor) -> torch.Tensor: if self._prefill_slice is not None: logits[self._prefill_slice] += self._prefill_bias From 
8f3c41b077e704ab130cd666f9d0db283cc22075 Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Tue, 14 Oct 2025 15:04:58 -0300 Subject: [PATCH 06/16] style: fix linting Signed-off-by: Wallas Santos --- vllm_spyre/v1/sample/spyre_logits_processor.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/vllm_spyre/v1/sample/spyre_logits_processor.py b/vllm_spyre/v1/sample/spyre_logits_processor.py index 4bf03fe6..e55c004c 100644 --- a/vllm_spyre/v1/sample/spyre_logits_processor.py +++ b/vllm_spyre/v1/sample/spyre_logits_processor.py @@ -3,14 +3,10 @@ import torch from vllm.logger import init_logger -from vllm.v1.sample.logits_processor import (BUILTIN_LOGITS_PROCESSORS, - STR_POOLING_REJECTS_LOGITSPROCS, - BatchUpdate, - LogitBiasLogitsProcessor, - LogitsProcessor, - MinPLogitsProcessor, - MinTokensLogitsProcessor, - _load_custom_logitsprocs) +from vllm.v1.sample.logits_processor import ( + BUILTIN_LOGITS_PROCESSORS, STR_POOLING_REJECTS_LOGITSPROCS, BatchUpdate, + LogitBiasLogitsProcessor, LogitsProcessor, MinPLogitsProcessor, + MinTokensLogitsProcessor, _load_custom_logitsprocs) from vllm.v1.sample.logits_processor.state import LogitsProcessors logger = init_logger(__name__) From db5cae28c02b468793d3219dd5fdf4cd89e2e31b Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Tue, 14 Oct 2025 15:07:32 -0300 Subject: [PATCH 07/16] style: fix linting Signed-off-by: Wallas Santos --- vllm_spyre/v1/sample/spyre_logits_processor.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/vllm_spyre/v1/sample/spyre_logits_processor.py b/vllm_spyre/v1/sample/spyre_logits_processor.py index e55c004c..4bf03fe6 100644 --- a/vllm_spyre/v1/sample/spyre_logits_processor.py +++ b/vllm_spyre/v1/sample/spyre_logits_processor.py @@ -3,10 +3,14 @@ import torch from vllm.logger import init_logger -from vllm.v1.sample.logits_processor import ( - BUILTIN_LOGITS_PROCESSORS, STR_POOLING_REJECTS_LOGITSPROCS, BatchUpdate, - LogitBiasLogitsProcessor, LogitsProcessor, MinPLogitsProcessor, - MinTokensLogitsProcessor, _load_custom_logitsprocs) +from vllm.v1.sample.logits_processor import (BUILTIN_LOGITS_PROCESSORS, + STR_POOLING_REJECTS_LOGITSPROCS, + BatchUpdate, + LogitBiasLogitsProcessor, + LogitsProcessor, + MinPLogitsProcessor, + MinTokensLogitsProcessor, + _load_custom_logitsprocs) from vllm.v1.sample.logits_processor.state import LogitsProcessors logger = init_logger(__name__) From 29a91da031b882827a82da6f008685b473caf28d Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Wed, 15 Oct 2025 11:39:54 -0300 Subject: [PATCH 08/16] style: bypass yapf Signed-off-by: Wallas Santos --- vllm_spyre/v1/sample/spyre_logits_processor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm_spyre/v1/sample/spyre_logits_processor.py b/vllm_spyre/v1/sample/spyre_logits_processor.py index 4bf03fe6..d57d9b5a 100644 --- a/vllm_spyre/v1/sample/spyre_logits_processor.py +++ b/vllm_spyre/v1/sample/spyre_logits_processor.py @@ -3,6 +3,7 @@ import torch from vllm.logger import init_logger +# yapf: disable from vllm.v1.sample.logits_processor import (BUILTIN_LOGITS_PROCESSORS, STR_POOLING_REJECTS_LOGITSPROCS, BatchUpdate, @@ -11,6 +12,7 @@ MinPLogitsProcessor, MinTokensLogitsProcessor, _load_custom_logitsprocs) +# yapf: enable from vllm.v1.sample.logits_processor.state import LogitsProcessors logger = init_logger(__name__) From b99d3a5e513a44919462e2ce11353e12eba31c94 Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Wed, 15 Oct 2025 12:19:00 -0300 Subject: [PATCH 09/16] feat: minor optimization 
on tests Signed-off-by: Wallas Santos --- tests/e2e/test_sampling_params.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/e2e/test_sampling_params.py b/tests/e2e/test_sampling_params.py index 5d141411..dae9b540 100644 --- a/tests/e2e/test_sampling_params.py +++ b/tests/e2e/test_sampling_params.py @@ -164,8 +164,11 @@ def test_spyre_batch1_n_generations(model: ModelInfo, backend, monkeypatch, def token_diversity(spyre_model, prompt, params, n_experiments): tokens = [] - for i in range(n_experiments): - output = spyre_model.generate(prompt, params)[0] + + outputs = spyre_model.generate([prompt] * n_experiments, + params, + use_tqdm=False) + for output in outputs: tokens.extend(output.outputs[0].token_ids) return len(set(tokens)) From ac427c27d2396e87191b466ec9ef31f97aebbe78 Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Thu, 16 Oct 2025 12:08:48 -0300 Subject: [PATCH 10/16] feat: new PrefillHelperLogitsProcessor to remove code deduplication Signed-off-by: Wallas Santos --- tests/utils/test_spyre_logits_processor.py | 2 +- .../v1/sample/spyre_logits_processor.py | 209 ++++++++---------- vllm_spyre/v1/worker/spyre_model_runner.py | 2 +- 3 files changed, 95 insertions(+), 118 deletions(-) diff --git a/tests/utils/test_spyre_logits_processor.py b/tests/utils/test_spyre_logits_processor.py index 04eb6724..b7a6fec8 100644 --- a/tests/utils/test_spyre_logits_processor.py +++ b/tests/utils/test_spyre_logits_processor.py @@ -43,7 +43,7 @@ def prefill(params: SamplingParams, batch_update_builder: BatchUpdateBuilder, lp.update_state(batch_update) if isinstance(lp, SpyreLogitsProcessor): - lp.set_prefill_index(req_idx) + lp.set_prefill(req_idx) out_logits = lp.apply(logits.clone()) ouput_tokens.append(0) # just append a random token diff --git a/vllm_spyre/v1/sample/spyre_logits_processor.py b/vllm_spyre/v1/sample/spyre_logits_processor.py index d57d9b5a..a4afacc6 100644 --- a/vllm_spyre/v1/sample/spyre_logits_processor.py +++ b/vllm_spyre/v1/sample/spyre_logits_processor.py @@ -13,6 +13,7 @@ MinTokensLogitsProcessor, _load_custom_logitsprocs) # yapf: enable +from vllm.v1.sample.logits_processor.interface import AddedRequest from vllm.v1.sample.logits_processor.state import LogitsProcessors logger = init_logger(__name__) @@ -24,7 +25,9 @@ SamplingParams = None VllmConfig = None -assert len(BUILTIN_LOGITS_PROCESSORS) == 3 +SPYRE_BUILTIN_LOGITS_PROCESSORS = [ + MinPLogitsProcessor, MinTokensLogitsProcessor, LogitBiasLogitsProcessor +] def build_logitsprocs_for_cb( @@ -35,6 +38,14 @@ def build_logitsprocs_for_cb( batch_size: int, custom_logitsprocs: Sequence[Union[str, type[LogitsProcessor]]] = (), ) -> LogitsProcessors: + + if len(BUILTIN_LOGITS_PROCESSORS) > 3: + logger.warning( + "There are %d logits processors, which is unexpected " + "for this vllm-spyre version. 
Consider upgrade " + "vllm-spyre or open an issue to investigate this", + len(BUILTIN_LOGITS_PROCESSORS)) + if is_pooling_model: if custom_logitsprocs: raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS) @@ -43,6 +54,12 @@ def build_logitsprocs_for_cb( return LogitsProcessors() custom_logitsprocs_classes = _load_custom_logitsprocs(custom_logitsprocs) + # Collect builtin LPs to fallback to the wrapper + builtin_logitsprocs = [lp for lp in BUILTIN_LOGITS_PROCESSORS \ + if lp not in SPYRE_BUILTIN_LOGITS_PROCESSORS] + + logitprocs_classes = custom_logitsprocs_classes + builtin_logitsprocs + return LogitsProcessors( itertools.chain( [SpyreLogitBiasLogitsProcessor(vllm_config, device, @@ -59,18 +76,78 @@ def build_logitsprocs_for_cb( device, is_pin_memory, batch_size) \ - for logit_processor in custom_logitsprocs_classes] + for logit_processor in logitprocs_classes] )) class SpyreLogitsProcessor: - def set_prefill_index(self, idx: int) -> None: + def set_prefill(self, idx: int) -> None: raise NotImplementedError +class PrefillHelperLogitsProcessor(LogitsProcessor, SpyreLogitsProcessor): + """ + Logits processor (LP) that separates two instances of a concrete LPS: + one for the prefill, and other for the batch. This class only works if + the state of the LP is independent between prefill and decoding. for + example this class is not suitable for the golden token injector LP. + """ + + def __init__(self, config: VllmConfig, device: torch.device, + is_pin_memory: bool, logit_processor: LogitsProcessor): + self._prefill_lp : LogitsProcessor = \ + logit_processor(config, device, is_pin_memory) + self._batch_lp : LogitsProcessor = \ + logit_processor(config, device, is_pin_memory) + + self._is_prefill: bool = False + + # This dictionary stores the sampling parameters of `update_state` so + # we can get when we call `set_prefill` to proper setup the prefill_lp. 
+ self._params: dict[int, AddedRequest] = {} + + def is_argmax_invariant(self) -> bool: + """Never impacts greedy sampling""" + return self._batch_lp.is_argmax_invariant() + + def update_state(self, batch_update: BatchUpdate | None): + + if batch_update: + for added_request in batch_update.added: + self._params[added_request[0]] = added_request + + # Always pass to the batch LP + self._batch_lp.update_state(batch_update) + + def set_prefill(self, idx: int) -> None: + + _, params, prompt_tok_ids, out_tok_ids = self._params[idx] + self._prefill_lp.update_state( + BatchUpdate( + batch_size=1, + removed=[], + moved=[], + added=[(0, params, prompt_tok_ids, out_tok_ids)], + )) + self._params.pop(idx) + self._is_prefill = True + + def apply(self, logits: torch.Tensor) -> torch.Tensor: + if self._is_prefill: + logits = self._prefill_lp.apply(logits) + + # Clean the prefill LP + self._is_prefill = False + self._prefill_lp.update_state( + BatchUpdate(batch_size=1, removed=[0], moved=[], added=[])) + return logits + + return self._batch_lp.apply(logits) + + class LogitsProcessorWrapper(LogitsProcessor, SpyreLogitsProcessor): - """Logit processor to inject expected token during generation for tests""" + """Logit processor to isolate logits processors to run individually""" def __init__(self, logit_processor: LogitsProcessor, vllm_config: VllmConfig, device: torch.device, @@ -126,129 +203,29 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: return logits - def set_prefill_index(self, idx: int) -> None: + def set_prefill(self, idx: int) -> None: self._prefill_index = idx -class SpyreMinPLogitsProcessor(MinPLogitsProcessor, SpyreLogitsProcessor): +class SpyreMinPLogitsProcessor(PrefillHelperLogitsProcessor): - def __init__(self, vllm_config: "VllmConfig", device: torch.device, + def __init__(self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool): - super().__init__(vllm_config, device, is_pin_memory) - self._prefill_index: Optional[int] = None - - def set_prefill_index(self, idx: int) -> None: - self._prefill_index = idx + super().__init__(vllm_config, device, is_pin_memory, + MinPLogitsProcessor) - def apply(self, logits: torch.Tensor) -> torch.Tensor: - if self._prefill_index is None: - return super().apply(logits) - - if not self.min_p_count: - return logits - - # Convert logits to probability distribution - probability_values = torch.nn.functional.softmax(logits, dim=-1) - # Calculate maximum probabilities per sequence - max_probabilities = torch.amax(probability_values, - dim=-1, - keepdim=True) - # Adjust min_p - adjusted_min_p = max_probabilities.mul_( - self.min_p[self._prefill_index].unsqueeze(0)) - # Identify valid tokens using threshold comparison - invalid_token_mask = probability_values < adjusted_min_p - # Apply mask using boolean indexing - logits[invalid_token_mask] = -float('inf') - self._prefill_index = None - - return logits +class SpyreLogitBiasLogitsProcessor(PrefillHelperLogitsProcessor): -class SpyreLogitBiasLogitsProcessor(LogitBiasLogitsProcessor, - SpyreLogitsProcessor): - - def __init__(self, config: VllmConfig, device: torch.device, + def __init__(self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool): - super().__init__(config, device, is_pin_memory) + super().__init__(vllm_config, device, is_pin_memory, + LogitBiasLogitsProcessor) - self._is_prefill: bool = False - self._prefill_slice : Optional[tuple[torch.Tensor, torch.Tensor]] \ - = None - self._prefill_bias: torch.Tensor = torch.tensor(()) - - def set_prefill_index(self, idx: 
int) -> None: - - reqs: list[int] = [] - tok_ids: list[int] = [] - biases: list[float] = [] - for req, lb in self.biases.items(): - if req == idx: - # NOTE: always request 0 for prefill - # prefill will only have logits for a single request - reqs.extend([0] * len(lb)) - tok_ids.extend(lb.keys()) - biases.extend(lb.values()) - - if biases: - self._prefill_slice = (self._device_tensor(reqs, torch.int32), - self._device_tensor(tok_ids, torch.int32)) - self._prefill_bias = self._device_tensor(biases, torch.float32) - self._is_prefill = True - def apply(self, logits: torch.Tensor) -> torch.Tensor: - if self._prefill_slice is not None: - logits[self._prefill_slice] += self._prefill_bias - self._prefill_slice = None - self._is_prefill = False - return logits - elif self._is_prefill: - # It is prefill but we do not need to do anything - # for the prefill request, just return logits to - # avoid slice the logits with batch_size = 1 - self._is_prefill = False - return logits - - return super().apply(logits) - - -class SpyreMinTokensLogitsProcessor(MinTokensLogitsProcessor, - SpyreLogitsProcessor): +class SpyreMinTokensLogitsProcessor(PrefillHelperLogitsProcessor): def __init__(self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool): - super().__init__(vllm_config, device, is_pin_memory) - self._prefill_slice : Optional[tuple[torch.Tensor, torch.Tensor]] \ - = None - self._is_prefill: bool = False - - def set_prefill_index(self, idx: int) -> None: - - reqs: list[int] = [] - tok_ids: list[int] = [] - for req, (_, _, stop_tok_ids) in self.min_toks.items(): - if req == idx: - # NOTE: always request 0 for prefill - # prefill will only have logits for a single request - reqs.extend([0] * len(stop_tok_ids)) - tok_ids.extend(stop_tok_ids) - - if reqs and tok_ids: - self._prefill_slice = (self._device_tensor(reqs, torch.int32), - self._device_tensor(tok_ids, torch.int32)) - self._is_prefill = True - - def apply(self, logits: torch.Tensor) -> torch.Tensor: - if self._prefill_slice is not None: - logits[self._prefill_slice] = -float("inf") - self._prefill_slice = None - self._is_prefill = False - return logits - elif self._is_prefill: - # It is prefill but we do not need to do anything - # for the prefill request, just return logits to - # avoid slice the logits with batch_size = 1 - self._is_prefill = False - return logits - - return super().apply(logits) + super().__init__(vllm_config, device, is_pin_memory, + MinTokensLogitsProcessor) diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py index f115c2e5..72edf4b3 100644 --- a/vllm_spyre/v1/worker/spyre_model_runner.py +++ b/vllm_spyre/v1/worker/spyre_model_runner.py @@ -1020,7 +1020,7 @@ def _prepare_prompt( # set prefill index for logits processor for logitsproc in self.input_batch.spyre_logitsprocs: - logitsproc.set_prefill_index(prefill_index) + logitsproc.set_prefill(prefill_index) self.model.indices = torch.ones(1, dtype=torch.bool, device='cpu') slot_mapping = torch.tensor([slots], dtype=torch.int64) From 83cb572a616a51c3a02223880985c4bb5a8aba79 Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Thu, 16 Oct 2025 13:57:16 -0300 Subject: [PATCH 11/16] fix: batch management Signed-off-by: Wallas Santos --- .../v1/sample/spyre_logits_processor.py | 23 ++++++++++++------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/vllm_spyre/v1/sample/spyre_logits_processor.py b/vllm_spyre/v1/sample/spyre_logits_processor.py index a4afacc6..48dedb8b 100644 --- 
a/vllm_spyre/v1/sample/spyre_logits_processor.py +++ b/vllm_spyre/v1/sample/spyre_logits_processor.py @@ -11,9 +11,9 @@ LogitsProcessor, MinPLogitsProcessor, MinTokensLogitsProcessor, - _load_custom_logitsprocs) + _load_custom_logitsprocs, + process_dict_updates) # yapf: enable -from vllm.v1.sample.logits_processor.interface import AddedRequest from vllm.v1.sample.logits_processor.state import LogitsProcessors logger = init_logger(__name__) @@ -105,24 +105,31 @@ def __init__(self, config: VllmConfig, device: torch.device, # This dictionary stores the sampling parameters of `update_state` so # we can get when we call `set_prefill` to proper setup the prefill_lp. - self._params: dict[int, AddedRequest] = {} + self._params: dict[int, tuple[SamplingParams, list[int], + list[int]]] = {} def is_argmax_invariant(self) -> bool: """Never impacts greedy sampling""" return self._batch_lp.is_argmax_invariant() + @staticmethod + def update_batch_params( + params: SamplingParams, prompt_tok_ids: list[int] | None, + output_tok_ids: list[int] + ) -> tuple[Sequence[int], set[int]] | None: + return params, prompt_tok_ids, prompt_tok_ids + def update_state(self, batch_update: BatchUpdate | None): - if batch_update: - for added_request in batch_update.added: - self._params[added_request[0]] = added_request + process_dict_updates(self._params, batch_update, + self.update_batch_params) # Always pass to the batch LP self._batch_lp.update_state(batch_update) def set_prefill(self, idx: int) -> None: - _, params, prompt_tok_ids, out_tok_ids = self._params[idx] + params, prompt_tok_ids, out_tok_ids = self._params[idx] self._prefill_lp.update_state( BatchUpdate( batch_size=1, @@ -130,7 +137,7 @@ def set_prefill(self, idx: int) -> None: moved=[], added=[(0, params, prompt_tok_ids, out_tok_ids)], )) - self._params.pop(idx) + # self._params.pop(idx) self._is_prefill = True def apply(self, logits: torch.Tensor) -> torch.Tensor: From f770f536585c752860c2c76fd2c0380c58d40e8b Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Thu, 16 Oct 2025 15:18:52 -0300 Subject: [PATCH 12/16] fix: typing Signed-off-by: Wallas Santos --- vllm_spyre/v1/sample/spyre_logits_processor.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm_spyre/v1/sample/spyre_logits_processor.py b/vllm_spyre/v1/sample/spyre_logits_processor.py index 48dedb8b..0700f7c5 100644 --- a/vllm_spyre/v1/sample/spyre_logits_processor.py +++ b/vllm_spyre/v1/sample/spyre_logits_processor.py @@ -114,10 +114,10 @@ def is_argmax_invariant(self) -> bool: @staticmethod def update_batch_params( - params: SamplingParams, prompt_tok_ids: list[int] | None, - output_tok_ids: list[int] - ) -> tuple[Sequence[int], set[int]] | None: - return params, prompt_tok_ids, prompt_tok_ids + params: SamplingParams, prompt_tok_ids: list[int] | None, + output_tok_ids: list[int] + ) -> tuple[SamplingParams, Sequence[int] | None, Sequence[int]] | None: + return params, prompt_tok_ids, output_tok_ids def update_state(self, batch_update: BatchUpdate | None): From 9a264f23e5d7d9ecbec7885a2ce347d2a84fcc15 Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Thu, 16 Oct 2025 16:58:24 -0300 Subject: [PATCH 13/16] refact: add GTI to main package Signed-off-by: Wallas Santos --- tests/llm_cache.py | 2 +- .../v1/sample}/golden_token_injector.py | 98 ++++++++++++------- 2 files changed, 65 insertions(+), 35 deletions(-) rename {tests => vllm_spyre/v1/sample}/golden_token_injector.py (70%) diff --git a/tests/llm_cache.py b/tests/llm_cache.py index b545d299..be8a5ae0 100644 --- 
a/tests/llm_cache.py +++ b/tests/llm_cache.py @@ -5,7 +5,7 @@ from typing import Callable, Generic, Optional, TypeVar import pytest -from golden_token_injector import GoldenTokenInjector +from vllm_spyre.v1.sample.golden_token_injector import GoldenTokenInjector from llm_cache_util import force_engine_shutdown from spyre_util import (DecodeWarmupShapes, ModelInfo, RemoteOpenAIServer, patch_environment) diff --git a/tests/golden_token_injector.py b/vllm_spyre/v1/sample/golden_token_injector.py similarity index 70% rename from tests/golden_token_injector.py rename to vllm_spyre/v1/sample/golden_token_injector.py index 126246ef..bb5cf123 100644 --- a/tests/golden_token_injector.py +++ b/vllm_spyre/v1/sample/golden_token_injector.py @@ -1,13 +1,15 @@ import math from typing import Optional +import json import torch import torch.nn.functional as F from vllm.config import VllmConfig +from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import get_tokenizer -from vllm.v1.sample.logits_processor import (BatchUpdate, LogitsProcessor, - MoveDirectionality) +from vllm.v1.sample.logits_processor import process_dict_updates +logger = init_logger(__name__) class ExpectationState: ''' @@ -26,13 +28,13 @@ class ExpectationState: def __init__(self, expected_token_ids: list[int], - expected_logprobs: list[float], - error_threshold: float, + expected_logprobs: Optional[list[float]], + error_threshold: Optional[float], label: Optional[str] = None): self.token_ids: list[int] = expected_token_ids - self.logprobs: list[float] = expected_logprobs - self.threshold: float = error_threshold + self.logprobs: Optional[list[float]] = expected_logprobs + self.threshold: Optional[float] = error_threshold self.label: Optional[str] = label self.current_token_idx = 0 @@ -54,37 +56,65 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device, def is_argmax_invariant(self) -> bool: """Never impacts greedy sampling""" return False + + + + @staticmethod + def add_req_states( + params: SamplingParams, prompt_tok_ids: list[int] | None, + output_tok_ids: list[int]) -> Optional[ExpectationState]: + + if params.extra_args and ( + injector_dict := + params.extra_args.get("golden_token_injector")): + + # OpenAI API can pass this parameter as string, so + # we will just parse as the expected dict + if isinstance(injector_dict, str): + injector_dict = json.loads(injector_dict) + + return ExpectationState(**injector_dict) + + return None def update_state(self, batch_update: Optional[BatchUpdate]): - # This method keeps the indices consistent of request while the - # persistent batch is changing. - if not batch_update: - return - - # Process added requests. - for index, params, _, _ in batch_update.added: - assert params is not None - if params.extra_args and ( - injector_dict := - params.extra_args.get("golden_token_injector")): - self.req_states[index] = ExpectationState(**injector_dict) - if not self.req_states: - return - - # Process removed requests. 
- for index in batch_update.removed: - self.req_states.pop(index, None) - - # Process moved requests, unidirectional move (a->b) and swap - # (a<->b) - for adx, bdx, direct in batch_update.moved: - a_val = self.req_states.pop(adx, None) - b_val = self.req_states.pop(bdx, None) - if a_val is not None: - self.req_states[bdx] = a_val - if direct == MoveDirectionality.SWAP and b_val is not None: - self.req_states[adx] = b_val + process_dict_updates(self.req_states, batch_update, self.add_req_states) + # # This method keeps the indices consistent of request while the + # # persistent batch is changing. + # if not batch_update: + # return + + # # Process added requests. + # for index, params, _, _ in batch_update.added: + # assert params is not None + # if params.extra_args and ( + # injector_dict := + # params.extra_args.get("golden_token_injector")): + + # # OpenAI API can pass this parameter as string, so + # # we will just parse as the expected dict + # if isinstance(injector_dict, str): + # injector_dict = json.loads(injector_dict) + + # self.req_states[index] = ExpectationState(**injector_dict) + + # if not self.req_states: + # return + + # # Process removed requests. + # for index in batch_update.removed: + # self.req_states.pop(index, None) + + # # Process moved requests, unidirectional move (a->b) and swap + # # (a<->b) + # for adx, bdx, direct in batch_update.moved: + # a_val = self.req_states.pop(adx, None) + # b_val = self.req_states.pop(bdx, None) + # if a_val is not None: + # self.req_states[bdx] = a_val + # if direct == MoveDirectionality.SWAP and b_val is not None: + # self.req_states[adx] = b_val def apply(self, logits: torch.Tensor) -> torch.Tensor: if not self.req_states: From fd9f557d48ddcf2d94845a72dd239fa9e79fb01e Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Fri, 17 Oct 2025 09:54:03 -0300 Subject: [PATCH 14/16] refact: code cleanup Signed-off-by: Wallas Santos --- tests/llm_cache.py | 3 +- vllm_spyre/v1/sample/golden_token_injector.py | 70 +++++++------------ 2 files changed, 27 insertions(+), 46 deletions(-) diff --git a/tests/llm_cache.py b/tests/llm_cache.py index be8a5ae0..c95e25d6 100644 --- a/tests/llm_cache.py +++ b/tests/llm_cache.py @@ -5,7 +5,6 @@ from typing import Callable, Generic, Optional, TypeVar import pytest -from vllm_spyre.v1.sample.golden_token_injector import GoldenTokenInjector from llm_cache_util import force_engine_shutdown from spyre_util import (DecodeWarmupShapes, ModelInfo, RemoteOpenAIServer, patch_environment) @@ -13,6 +12,8 @@ from vllm.v1.engine.core import EngineCore from vllm.v1.executor.abstract import Executor +from vllm_spyre.v1.sample.golden_token_injector import GoldenTokenInjector + T = TypeVar("T") ## class definitions ########################################## diff --git a/vllm_spyre/v1/sample/golden_token_injector.py b/vllm_spyre/v1/sample/golden_token_injector.py index bb5cf123..48d09d9e 100644 --- a/vllm_spyre/v1/sample/golden_token_injector.py +++ b/vllm_spyre/v1/sample/golden_token_injector.py @@ -1,16 +1,25 @@ import math -from typing import Optional +from typing import Optional, TYPE_CHECKING import json import torch import torch.nn.functional as F -from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import get_tokenizer -from vllm.v1.sample.logits_processor import process_dict_updates +from vllm.v1.sample.logits_processor import process_dict_updates, BatchUpdate + +from vllm.v1.sample.logits_processor import LogitsProcessor logger = init_logger(__name__) +if 
TYPE_CHECKING: + from vllm import SamplingParams + from vllm.config import VllmConfig + +else: + VllmConfig = None + SamplingParams = None + class ExpectationState: ''' This class controls the state of the generation. @@ -78,43 +87,7 @@ def add_req_states( return None def update_state(self, batch_update: Optional[BatchUpdate]): - process_dict_updates(self.req_states, batch_update, self.add_req_states) - # # This method keeps the indices consistent of request while the - # # persistent batch is changing. - # if not batch_update: - # return - - # # Process added requests. - # for index, params, _, _ in batch_update.added: - # assert params is not None - # if params.extra_args and ( - # injector_dict := - # params.extra_args.get("golden_token_injector")): - - # # OpenAI API can pass this parameter as string, so - # # we will just parse as the expected dict - # if isinstance(injector_dict, str): - # injector_dict = json.loads(injector_dict) - - # self.req_states[index] = ExpectationState(**injector_dict) - - # if not self.req_states: - # return - - # # Process removed requests. - # for index in batch_update.removed: - # self.req_states.pop(index, None) - - # # Process moved requests, unidirectional move (a->b) and swap - # # (a<->b) - # for adx, bdx, direct in batch_update.moved: - # a_val = self.req_states.pop(adx, None) - # b_val = self.req_states.pop(bdx, None) - # if a_val is not None: - # self.req_states[bdx] = a_val - # if direct == MoveDirectionality.SWAP and b_val is not None: - # self.req_states[adx] = b_val def apply(self, logits: torch.Tensor) -> torch.Tensor: if not self.req_states: @@ -181,7 +154,7 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: (full_prob - prob) / other_token_ids_count) if lp < other_logprobs: - print("The logprob is lower than the redistributed " + logger.warn("The logprob is lower than the redistributed " "logprobs for the token ids " f"({lp.item()} < {other_logprobs.item()}), this " "suggests that the generation diverged too much " @@ -197,11 +170,18 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: expected_token = self.tokenizer.decode([expected_token_id]) old_prob = logprobs[req_idx][token_id].exp().item() - print(f"Golden token injection for request {label}"\ - f" at token index '{expectation.current_token_idx}':") - print(f"'{token}' ({old_prob * 100:.2f}%) replaced by" - f" '{expected_token}' ({prob * 100:.2f}%);" - f" baseline: ({expected_prob * 100:.2f}%)") + logger.info("Golden token injection for request %s"\ + " at token index '%d': " + "'%s' (%.2f%%) replaced by " + "'%s' (%.2f%%);" + " baseline: (%.2f%%)", + label, + expectation.current_token_idx, + token, + old_prob * 100, + expected_token, + prob * 100, + expected_prob * 100) expectation.current_token_idx += 1 return logits From 0cf23ff4425ea6afbca35be1af3014337e422459 Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Fri, 17 Oct 2025 12:00:08 -0300 Subject: [PATCH 15/16] feat: improved gti Signed-off-by: Wallas Santos --- tests/llm_cache.py | 12 +- vllm_spyre/v1/sample/golden_token_injector.py | 244 ++++++++++-------- .../v1/sample/spyre_logits_processor.py | 5 + 3 files changed, 147 insertions(+), 114 deletions(-) diff --git a/tests/llm_cache.py b/tests/llm_cache.py index c95e25d6..2d027c1f 100644 --- a/tests/llm_cache.py +++ b/tests/llm_cache.py @@ -1,7 +1,6 @@ """Contains utilities for caching models (instantiated as vLLM endpoints) across test cases, to speed up test runtime.""" -import os from typing import Callable, Generic, Optional, TypeVar import pytest @@ -12,8 
+11,6 @@ from vllm.v1.engine.core import EngineCore from vllm.v1.executor.abstract import Executor -from vllm_spyre.v1.sample.golden_token_injector import GoldenTokenInjector - T = TypeVar("T") ## class definitions ########################################## @@ -180,12 +177,6 @@ def get_engine( revision = None model_name = model - # Register golden token injector if not disabled - disable_golden_token = \ - bool(int(os.getenv("VLLM_SPYRE_TEST_DISABLE_GOLDEN_TOKEN", "0"))) - logits_processors = [] if disable_golden_token else \ - [GoldenTokenInjector] - # 🌶️🌶️🌶️ # Messing with the blocks and context length by either: # - setting context < 512 tokens @@ -205,8 +196,7 @@ def get_engine( max_model_len=max(max_model_len, 512), max_num_seqs=max_num_seqs_compiled, num_gpu_blocks_override=None, - revision=revision, - logits_processors=logits_processors) + revision=revision) vllm_config = engine_args.create_engine_config() executor_class = Executor.get_class(vllm_config) diff --git a/vllm_spyre/v1/sample/golden_token_injector.py b/vllm_spyre/v1/sample/golden_token_injector.py index 48d09d9e..e349f90a 100644 --- a/vllm_spyre/v1/sample/golden_token_injector.py +++ b/vllm_spyre/v1/sample/golden_token_injector.py @@ -1,14 +1,15 @@ +import json import math -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Optional, cast -import json import torch import torch.nn.functional as F from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import get_tokenizer -from vllm.v1.sample.logits_processor import process_dict_updates, BatchUpdate +from vllm.v1.sample.logits_processor import (BatchUpdate, LogitsProcessor, + process_dict_updates) -from vllm.v1.sample.logits_processor import LogitsProcessor +from vllm_spyre.v1.sample.spyre_logits_processor import SpyreLogitsProcessor logger = init_logger(__name__) @@ -20,6 +21,7 @@ VllmConfig = None SamplingParams = None + class ExpectationState: ''' This class controls the state of the generation. @@ -42,7 +44,7 @@ def __init__(self, label: Optional[str] = None): self.token_ids: list[int] = expected_token_ids - self.logprobs: Optional[list[float]] = expected_logprobs + self.logprobs: Optional[list[float]] = expected_logprobs self.threshold: Optional[float] = error_threshold self.label: Optional[str] = label @@ -50,44 +52,47 @@ def __init__(self, self.has_error = False -class GoldenTokenInjector(LogitsProcessor): +class GoldenTokenInjector(SpyreLogitsProcessor, LogitsProcessor): """Logit processor to inject expected token during generation for tests""" def __init__(self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool): self.req_states: dict[int, ExpectationState] = {} - # NOTE: This logit processor hold a tokenizer for each instance. - # for couple requests that does not have too much impact. - # But since this is used mostly for validation, it would be fine - # to keep them. 
self.tokenizer = get_tokenizer(vllm_config.model_config.tokenizer) + self.prefill_index: Optional[int] = None + def is_argmax_invariant(self) -> bool: """Never impacts greedy sampling""" return False - - + + def set_prefill(self, idx: int) -> None: + self.prefill_index = idx @staticmethod def add_req_states( - params: SamplingParams, prompt_tok_ids: list[int] | None, - output_tok_ids: list[int]) -> Optional[ExpectationState]: + params: SamplingParams, prompt_tok_ids: list[int] | None, + output_tok_ids: list[int]) -> Optional[ExpectationState]: if params.extra_args and ( - injector_dict := - params.extra_args.get("golden_token_injector")): + injector_dict := + params.extra_args.get("golden_token_injector")): # OpenAI API can pass this parameter as string, so # we will just parse as the expected dict if isinstance(injector_dict, str): injector_dict = json.loads(injector_dict) + elif not isinstance(injector_dict, dict): + raise ValueError( + "Golden token injector accepts only str or dict.") return ExpectationState(**injector_dict) return None def update_state(self, batch_update: Optional[BatchUpdate]): - process_dict_updates(self.req_states, batch_update, self.add_req_states) + process_dict_updates(self.req_states, batch_update, + self.add_req_states) def apply(self, logits: torch.Tensor) -> torch.Tensor: if not self.req_states: @@ -96,92 +101,125 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: # Calculate logprobs for the current model execution logprobs = F.log_softmax(logits, dim=-1) - for req_idx, expectation in self.req_states.items(): - - if expectation.has_error: - # There was an error already for inject tokens for this - # request, skip until the end of its generation. - continue - - expected_token_id = expectation.token_ids[ - expectation.current_token_idx] - token_id = torch.argmax(logits[req_idx], dim=-1) - - if expected_token_id == token_id: - # Expectation is met, nothing to do. - expectation.current_token_idx += 1 - continue - - # Get the logprob for the expected token - lp = logprobs[req_idx][expected_token_id].reshape(-1) - prob = torch.exp(lp).item() - - expected_logprob = \ - expectation.logprobs[expectation.current_token_idx] - expected_prob = math.exp(expected_logprob) - - # Label to identify request, if the label was set in the state, - # use it, otherwise it will be the index of the request in the - # batch - - label = f"'{expectation.label}'" if expectation.label is not None \ - else f"idx '{req_idx}'" - - # We'll inject only if the error is below the threshold - if not math.isclose( - expected_prob, prob, abs_tol=expectation.threshold): - err = abs(expected_prob - prob) - - print("Token probability is out of the acceptable threshold " - f"{err:.2f} > {expectation.threshold:.2f} at request " - f"{label} token idx '{expectation.current_token_idx}'." - " Token injection will be skipped.") - expectation.has_error = True - continue - - full_prob = torch.ones(1, dtype=logprobs.dtype) # 100% - - # Keep the same logprob for the expected token and - # redistribute evenly the probability among the other - # token ids. - # NOTE: we are setting logprobs to the logits, if we recalculate - # the softmax again over this distribution we shall find the same - # values, but with some minimal difference. The intention is - # inject the golden token but preserving the original logprob. 
-
-            other_token_ids_count = logits.shape[1] - 1
-            other_logprobs = torch.log(
-                (full_prob - prob) / other_token_ids_count)
-
-            if lp < other_logprobs:
-                logger.warn("The logprob is lower than the redistributed "
-                            "logprobs for the token ids "
-                            f"({lp.item()} < {other_logprobs.item()}), this "
-                            "suggests that the generation diverged too much "
-                            "from the expectation.")
-                expectation.has_error = True
-                continue
-
-            logits[req_idx] = other_logprobs
-            logits[req_idx][expected_token_id] = lp
-
-            # Decode the tokens for better human readability
-            token = self.tokenizer.decode([token_id])
-            expected_token = self.tokenizer.decode([expected_token_id])
-            old_prob = logprobs[req_idx][token_id].exp().item()
+        if self.prefill_index is not None:
+            expectation = self.req_states[self.prefill_index]
+            # zero because for prefill there's only a request in the batch
+            self.inject_token(logits, logprobs, 0, expectation)
+            self.prefill_index = None
+        else:
+            for req_idx, expectation in self.req_states.items():
+                self.inject_token(logits, logprobs, req_idx, expectation)
 
-            logger.info("Golden token injection for request %s"\
-                    " at token index '%d': "
-                    "'%s' (%.2f%%) replaced by "
-                    "'%s' (%.2f%%);"
-                    " baseline: (%.2f%%)",
-                    label,
-                    expectation.current_token_idx,
-                    token,
-                    old_prob * 100,
-                    expected_token,
-                    prob * 100,
-                    expected_prob * 100)
+        return logits
+
+    def inject_token(self, logits: torch.Tensor, logprobs: torch.Tensor,
+                     req_idx: int, expectation: ExpectationState):
+        if expectation.has_error:
+            # There was an error already for inject tokens for this
+            # request, skip until the end of its generation.
+            return
+
+        expected_token_id = expectation.token_ids[
+            expectation.current_token_idx]
+        token_id = torch.argmax(logits[req_idx], dim=-1)
+
+        if expected_token_id == token_id:
+            # Expectation is met, nothing to do.
             expectation.current_token_idx += 1
+            return
 
-        return logits
+        # Label to identify request, if the label was set in the state,
+        # use it, otherwise it will be the index of the request in the
+        # batch
+
+        label = f"'{expectation.label}'" if expectation.label is not None \
+            else f"idx '{req_idx}'"
+
+        # Decode the tokens for better human readability
+        token = self.tokenizer.decode([token_id])
+        expected_token = self.tokenizer.decode([expected_token_id])
+
+        if expectation.logprobs is None or \
+            expectation.threshold is None:
+
+            # Always inject the token
+            logits[req_idx] = -math.inf
+            logits[req_idx][expected_token_id] = 0.0
+
+            logger.info("Golden token injection for request %s"\
+                " at token index '%d': "
+                "'%s' replaced by '%s'",
+                label,
+                expectation.current_token_idx,
+                token,
+                expected_token)
+
+            return
+
+        # Check if the token is injectable based on a threshold
+        token_lp = logprobs[req_idx][expected_token_id].reshape(-1)
+        prob = torch.exp(token_lp).item()
+
+        expected_logprob = \
+            cast(torch.Tensor, expectation.logprobs)[
+                expectation.current_token_idx
+            ]
+        expected_prob = math.exp(expected_logprob)
+
+        # We'll inject only if the error is below the threshold
+        if not math.isclose(expected_prob,
+                            prob,
+                            abs_tol=cast(float, expectation.threshold)):
+            err = abs(expected_prob - prob)
+
+            logger.error(
+                "Token probability is out of the acceptable threshold "
+                "%.2f > %.2f at request "
+                "%s token idx '%s'."
+ " Token injection will be skipped.", err, + expectation.threshold, label, expectation.current_token_idx) + expectation.has_error = True + return + + full_prob = torch.ones(1, dtype=logprobs.dtype) # 100% + + # Keep the same logprob for the expected token and + # redistribute evenly the probability among the other + # token ids. + # NOTE: we are setting logprobs to the logits, if we recalculate + # the softmax again over this distribution we shall find the same + # values, but with some minimal difference. The intention is + # inject the golden token but preserving the original logprob. + + other_token_ids_count = logits.shape[1] - 1 + other_logprobs = torch.log((full_prob - prob) / other_token_ids_count) + + if token_lp < other_logprobs: + logger.warning( + "The logprob is lower than the redistributed " + "logprobs for the token ids " + "(%.4f < %.4f), this " + "suggests that the generation diverged too much " + "from the expectation.", token_lp.item(), + other_logprobs.item()) + expectation.has_error = True + return + + logits[req_idx] = other_logprobs + logits[req_idx][expected_token_id] = token_lp + + old_prob = logprobs[req_idx][token_id].exp().item() + + logger.info("Golden token injection for request %s"\ + " at token index '%d': " + "'%s' (%.2f%%) replaced by " + "'%s' (%.2f%%);" + " baseline: (%.2f%%)", + label, + expectation.current_token_idx, + token, + old_prob * 100, + expected_token, + prob * 100, + expected_prob * 100) + expectation.current_token_idx += 1 diff --git a/vllm_spyre/v1/sample/spyre_logits_processor.py b/vllm_spyre/v1/sample/spyre_logits_processor.py index 0700f7c5..5bc3f9ee 100644 --- a/vllm_spyre/v1/sample/spyre_logits_processor.py +++ b/vllm_spyre/v1/sample/spyre_logits_processor.py @@ -60,6 +60,9 @@ def build_logitsprocs_for_cb( logitprocs_classes = custom_logitsprocs_classes + builtin_logitsprocs + # To avoid circular import + from vllm_spyre.v1.sample.golden_token_injector import GoldenTokenInjector + return LogitsProcessors( itertools.chain( [SpyreLogitBiasLogitsProcessor(vllm_config, device, @@ -70,6 +73,8 @@ def build_logitsprocs_for_cb( SpyreMinTokensLogitsProcessor(vllm_config, device, is_pin_memory), + GoldenTokenInjector(vllm_config, device, is_pin_memory) + ], [LogitsProcessorWrapper(logit_processor, vllm_config, From 9d0514f5eca3c97c7c0e6eec6ad9fc66d77e6595 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Mon, 20 Oct 2025 15:45:44 -0700 Subject: [PATCH 16/16] =?UTF-8?q?=F0=9F=90=9B=20make=20args=20optional?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm_spyre/v1/sample/golden_token_injector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_spyre/v1/sample/golden_token_injector.py b/vllm_spyre/v1/sample/golden_token_injector.py index e349f90a..7daf7d18 100644 --- a/vllm_spyre/v1/sample/golden_token_injector.py +++ b/vllm_spyre/v1/sample/golden_token_injector.py @@ -39,8 +39,8 @@ class ExpectationState: def __init__(self, expected_token_ids: list[int], - expected_logprobs: Optional[list[float]], - error_threshold: Optional[float], + expected_logprobs: Optional[list[float]] = None, + error_threshold: Optional[float] = None, label: Optional[str] = None): self.token_ids: list[int] = expected_token_ids