diff --git a/tests/e2e/test_sampling_params.py b/tests/e2e/test_sampling_params.py
index 7f2e6d3f..dae9b540 100644
--- a/tests/e2e/test_sampling_params.py
+++ b/tests/e2e/test_sampling_params.py
@@ -164,8 +164,11 @@ def test_spyre_batch1_n_generations(model: ModelInfo, backend, monkeypatch,
 def token_diversity(spyre_model, prompt, params, n_experiments):
     tokens = []
-    for i in range(n_experiments):
-        output = spyre_model.generate(prompt, params)[0]
+
+    outputs = spyre_model.generate([prompt] * n_experiments,
+                                   params,
+                                   use_tqdm=False)
+    for output in outputs:
         tokens.extend(output.outputs[0].token_ids)
     return len(set(tokens))
 
 
@@ -210,15 +213,17 @@ def test_spyre_batch1_top_k(model: ModelInfo, backend, monkeypatch,
 def test_spyre_batch1_logit_bias(model: ModelInfo, backend, monkeypatch,
-                                 use_llm_cache, warmup_shapes):
+                                 use_llm_cache, warmup_shapes, max_model_len,
+                                 max_num_seqs, cb: int):
     spyre_model = get_cached_llm(
         model=model,
-        max_model_len=128,
+        max_model_len=max_model_len,
+        max_num_seqs=max_num_seqs,
         tensor_parallel_size=1,
         backend=backend,
         monkeypatch=monkeypatch,
-        warmup_shapes=warmup_shapes,
-    )
+        warmup_shapes=warmup_shapes if cb == 0 else None,
+        use_cb=cb == 1)
     tokenizer = spyre_model.get_tokenizer()
     banned_word = "train"
     forced_word = "plane"
@@ -239,25 +244,26 @@ def test_spyre_batch1_logit_bias(model: ModelInfo, backend, monkeypatch,
     })
     params2 = SamplingParams(temperature=0, seed=8780, max_tokens=5)
-    output1 = spyre_model.generate(prompt, params1)[0]
-    output2 = spyre_model.generate(prompt, params2)[0]
+    output = spyre_model.generate([prompt, prompt], [params1, params2])
 
-    assert banned_word not in output1.outputs[0].text.lower()
-    assert forced_word in output1.outputs[0].text.lower()
+    assert banned_word not in output[0].outputs[0].text.lower()
+    assert forced_word in output[0].outputs[0].text.lower()
 
-    assert output1.outputs[0].text != output2.outputs[0].text
+    assert output[0].outputs[0].text != output[1].outputs[0].text
 
 
 def test_spyre_batch1_min_tokens(model: ModelInfo, backend, monkeypatch,
-                                 use_llm_cache, warmup_shapes):
+                                 use_llm_cache, max_model_len, max_num_seqs,
+                                 warmup_shapes, cb: int):
     spyre_model = get_cached_llm(
         model=model,
-        max_model_len=128,
+        max_model_len=max_model_len,
         tensor_parallel_size=1,
         backend=backend,
         monkeypatch=monkeypatch,
-        warmup_shapes=warmup_shapes,
-    )
+        warmup_shapes=warmup_shapes if cb != 1 else None,
+        max_num_seqs=max_num_seqs if cb == 1 else None,
+        use_cb=cb == 1)
     prompt = "What is the capital of the USA?"
     tokenizer = spyre_model.get_tokenizer()
     eos_id = tokenizer.eos_token_id
@@ -268,11 +274,10 @@ def test_spyre_batch1_min_tokens(model: ModelInfo, backend, monkeypatch,
                                max_tokens=20)
     params2 = SamplingParams(seed=8780, logit_bias={eos_id: 50}, max_tokens=20)
-    output1 = spyre_model.generate(prompt, params1)[0]
-    output2 = spyre_model.generate(prompt, params2)[0]
+    output = spyre_model.generate([prompt] * 2, [params1, params2])
 
-    assert len(output1.outputs[0].token_ids) >= 19
-    assert len(output2.outputs[0].token_ids) < 19
+    assert len(output[0].outputs[0].token_ids) >= 19
+    assert len(output[1].outputs[0].token_ids) < 19
 
 
 def test_spyre_batch1_ignore_eos(model: ModelInfo, backend, monkeypatch,
@@ -310,15 +315,17 @@ def test_spyre_batch1_ignore_eos(model: ModelInfo, backend, monkeypatch,
 def test_spyre_batch1_min_p(model: ModelInfo, backend, monkeypatch,
-                            use_llm_cache, warmup_shapes):
+                            use_llm_cache, max_model_len, max_num_seqs,
+                            warmup_shapes, cb: int):
     spyre_model = get_cached_llm(
         model=model,
-        max_model_len=128,
+        max_model_len=max_model_len,
+        max_num_seqs=max_num_seqs,
         tensor_parallel_size=1,
         backend=backend,
         monkeypatch=monkeypatch,
-        warmup_shapes=warmup_shapes,
-    )
+        warmup_shapes=warmup_shapes if cb == 0 else None,
+        use_cb=cb == 1)
     prompt = "The opposite of black is"
     params1 = SamplingParams(min_p=0.5, temperature=1, max_tokens=5)
     params2 = SamplingParams(temperature=1, max_tokens=5)
diff --git a/tests/utils/test_spyre_logits_processor.py b/tests/utils/test_spyre_logits_processor.py
new file mode 100644
index 00000000..fda34f9c
--- /dev/null
+++ b/tests/utils/test_spyre_logits_processor.py
@@ -0,0 +1,312 @@
+import math
+
+import pytest
+import torch
+from vllm.sampling_params import SamplingParams
+from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
+                                             MinPLogitsProcessor)
+
+from vllm_spyre.v1.sample.spyre_logits_processor import (
+    SpyreLogitBiasLogitsProcessor, SpyreLogitsProcessor,
+    SpyreMinPLogitsProcessor, SpyreMinTokensLogitsProcessor)
+
+EOS_TOKEN = 3
+VOCAB_SIZE = 8
+
+
+class DummyVllmConfig:
+
+    def __init__(self):
+        self.scheduler_config = DummySchedulerConfig()
+
+
+class DummySchedulerConfig:
+
+    def __init__(self, max_num_seqs=4):
+        self.max_num_seqs = max_num_seqs
+
+
+def generate_logits(batch_size: int = 1):
+
+    return torch.tensor(list(range(VOCAB_SIZE)) * batch_size,
+                        dtype=torch.float32).reshape((batch_size, VOCAB_SIZE))
+
+
+def prefill(params: SamplingParams, batch_update_builder: BatchUpdateBuilder,
+            lp: SpyreLogitsProcessor, logits: torch.Tensor,
+            output_tokens: list[int], req_idx: int, num_reqs: int):
+    params._all_stop_token_ids = set([EOS_TOKEN])  # normally set by the engine
+    prompt_tokens = [0] * 8
+    batch_update_builder.added.append(
+        (req_idx, params, prompt_tokens, output_tokens))
+    batch_update = batch_update_builder.get_and_reset(num_reqs)
+    lp.update_state(batch_update)
+
+    if isinstance(lp, SpyreLogitsProcessor):
+        lp.set_prefill(req_idx)
+
+    out_logits = lp.apply(logits.clone())
+    output_tokens.append(0)  # just append an arbitrary token
+
+    return out_logits
+
+
+def decode(batch_update_builder: BatchUpdateBuilder, lp: SpyreLogitsProcessor,
+           logits: torch.Tensor, batch_output_tokens: list[list[int]]):
+
+    assert logits.shape[0] == len(batch_output_tokens)
+
+    # Called on every execute_model step in the model runner's update_states
+    lp.update_state(None)
+
+    out_logits = lp.apply(logits.clone())
+
+    for output_tokens in batch_output_tokens:
+        output_tokens.append(0)  # just append an arbitrary token
+
+    return out_logits
+
+
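For orientation, the tests below all follow the same pattern: prefill each request once (which also seeds the Spyre LP via `set_prefill`), then drive batched decode steps. A condensed sketch of that flow, reusing the helpers and dummy config defined above (the `min_tokens` value is an arbitrary illustration, not an additional test case):

```python
# Condensed usage sketch of the helpers above; values are illustrative only.
def _sketch_single_request_flow():
    lp = SpyreMinTokensLogitsProcessor(DummyVllmConfig(), torch.device('cpu'),
                                       False)
    builder = BatchUpdateBuilder()
    out_tokens: list[int] = []

    # Prefill request #0 with a min_tokens constraint...
    prefill(SamplingParams(min_tokens=2), builder, lp, generate_logits(1),
            out_tokens, req_idx=0, num_reqs=1)

    # ...then run a few decode steps over the (single-request) batch.
    for _ in range(3):
        decode(builder, lp, generate_logits(1), [out_tokens])
```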
+@pytest.mark.cpu
+@pytest.mark.worker
+def test_mintokens_logits_processor():
+    """
+    This test covers the SpyreMinTokensLogitsProcessor class.
+
+    The test case partially simulates engine steps, focusing only on the
+    logits processor.
+    The logits processor should set the EOS token's logit to -inf until the
+    request reaches the minimum number of tokens specified in SamplingParams.
+    """
+    device = torch.device('cpu')
+
+    dummy_config = DummyVllmConfig()
+    lp = SpyreMinTokensLogitsProcessor(dummy_config, device, False)
+
+    batch_update_builder = BatchUpdateBuilder()
+
+    batch_output_tokens = [[], [], []]
+
+    # Step #0 Prefill req_id #0 (no min tokens)
+    logits = generate_logits(1)
+    out_logits = prefill(SamplingParams(),
+                         batch_update_builder,
+                         lp,
+                         logits,
+                         batch_output_tokens[0],
+                         req_idx=0,
+                         num_reqs=1)
+
+    assert torch.allclose(logits, out_logits)
+
+    # Step #1 Prefill req_id #1 (with min tokens)
+    params = SamplingParams(min_tokens=4)
+
+    logits = generate_logits(1)
+    out_logits = prefill(params,
+                         batch_update_builder,
+                         lp,
+                         logits,
+                         batch_output_tokens[1],
+                         req_idx=1,
+                         num_reqs=2)
+
+    assert out_logits[0][EOS_TOKEN].item() == -math.inf
+
+    # Step #2 Prefill req_id #2
+    logits = generate_logits(1)
+    out_logits = prefill(SamplingParams(),
+                         batch_update_builder,
+                         lp,
+                         logits,
+                         batch_output_tokens[2],
+                         req_idx=2,
+                         num_reqs=3)
+
+    assert torch.allclose(logits, out_logits)
+
+    # Steps #3 - #5 Decoding, eos_token for req #1 must be -inf
+    for _ in range(3):
+        logits = generate_logits(3)
+        out_logits = decode(
+            batch_update_builder,
+            lp,
+            logits,
+            batch_output_tokens,
+        )
+
+        assert torch.allclose(logits[0], out_logits[0])
+        assert torch.allclose(logits[2], out_logits[2])
+        assert out_logits[1][EOS_TOKEN].item() == -math.inf
+
+    # Step #6, min tokens reached, no changes in the logits anymore
+    logits = generate_logits(3)
+    out_logits = decode(
+        batch_update_builder,
+        lp,
+        logits,
+        batch_output_tokens,
+    )
+
+    assert torch.allclose(logits, out_logits)
+
+
+@pytest.mark.cpu
+@pytest.mark.worker
+def test_logitbias_logits_processor():
+    device = torch.device('cpu')
+
+    dummy_config = DummyVllmConfig()
+    lp = SpyreLogitBiasLogitsProcessor(dummy_config, device, False)
+
+    batch_update_builder = BatchUpdateBuilder()
+
+    batch_output_tokens = [[], [], []]
+
+    # Step #0 Prefill req_id #0 (no logit bias)
+    logits = generate_logits(1)
+    out_logits = prefill(SamplingParams(),
+                         batch_update_builder,
+                         lp,
+                         logits,
+                         batch_output_tokens[0],
+                         req_idx=0,
+                         num_reqs=1)
+
+    assert torch.allclose(logits, out_logits)
+
+    # Step #1 Prefill req_id #1 (with logit bias)
+    params = SamplingParams(logit_bias={1: 100, 7: -100})
+
+    logits = generate_logits(1)
+    out_logits = prefill(params,
+                         batch_update_builder,
+                         lp,
+                         logits,
+                         batch_output_tokens[1],
+                         req_idx=1,
+                         num_reqs=2)
+
+    assert out_logits[0][1].item() >= 100
+    assert out_logits[0][7].item() < 0
+
+    # Step #2 Prefill req_id #2 (no logit bias)
+    logits = generate_logits(1)
+    out_logits = prefill(SamplingParams(),
+                         batch_update_builder,
+                         lp,
+                         logits,
+                         batch_output_tokens[2],
+                         req_idx=2,
+                         num_reqs=3)
+
+    assert torch.allclose(logits, out_logits)
+
+    # Step #3, decoding, keep applying the logit bias
+    logits = generate_logits(3)
+    out_logits = decode(
+        batch_update_builder,
+        lp,
+        logits,
+        batch_output_tokens,
+    )
+
+    assert torch.allclose(logits[0], out_logits[0])
+    assert torch.allclose(logits[2], out_logits[2])
+    assert out_logits[1][1].item() >= 100
+    assert out_logits[1][7].item() < 0
+
+
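The thresholds in the logit-bias assertions follow directly from the synthetic logits: each row produced by `generate_logits` is simply `[0., 1., ..., 7.]`, and an OpenAI-style logit bias is added to the chosen columns. A back-of-the-envelope check of the expected arithmetic (illustrative only, mirroring the test above):

```python
# Assumed illustration of the arithmetic behind the asserts, not a test case.
base = list(range(8))       # one row produced by generate_logits()
bias = {1: 100, 7: -100}    # the logit_bias used in step #1 above
biased = [v + bias.get(i, 0) for i, v in enumerate(base)]
assert biased[1] == 101 >= 100   # why the test asserts >= 100
assert biased[7] == -93 < 0      # why the test asserts < 0
```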
+@pytest.mark.cpu
+@pytest.mark.worker
+def test_minp_logits_processor():
+    device = torch.device('cpu')
+
+    dummy_config = DummyVllmConfig()
+    lp = SpyreMinPLogitsProcessor(dummy_config, device, False)
+    # Reference vllm logits processor for comparison
+    vllm_lp = MinPLogitsProcessor(dummy_config, device, False)
+
+    batch_update_builder = BatchUpdateBuilder()
+
+    batch_output_tokens = [[], [], []]
+
+    # Step #0 Prefill req_id #0 (no min p)
+    logits = generate_logits(1)
+    out_logits = prefill(SamplingParams(),
+                         batch_update_builder,
+                         lp,
+                         logits,
+                         batch_output_tokens[0],
+                         req_idx=0,
+                         num_reqs=1)
+    vllm_logits = prefill(SamplingParams(),
+                          batch_update_builder,
+                          vllm_lp,
+                          logits,
+                          batch_output_tokens[0],
+                          req_idx=0,
+                          num_reqs=1)
+
+    assert torch.allclose(vllm_logits, out_logits)
+
+    # Step #1 Prefill req_id #1 (with min p)
+    params = SamplingParams(min_p=0.15)
+
+    logits = generate_logits(1)
+    out_logits = prefill(params,
+                         batch_update_builder,
+                         lp,
+                         logits,
+                         batch_output_tokens[1],
+                         req_idx=1,
+                         num_reqs=2)
+    logits = generate_logits(2)
+    vllm_logits = prefill(params,
+                          batch_update_builder,
+                          vllm_lp,
+                          logits,
+                          batch_output_tokens[1],
+                          req_idx=1,
+                          num_reqs=2)
+
+    assert torch.allclose(vllm_logits[1], out_logits[0])
+
+    # Step #2 Prefill req_id #2 (no min p)
+    logits = generate_logits(1)
+    out_logits = prefill(SamplingParams(),
+                         batch_update_builder,
+                         lp,
+                         logits,
+                         batch_output_tokens[2],
+                         req_idx=2,
+                         num_reqs=3)
+
+    logits = generate_logits(3)
+    vllm_logits = prefill(SamplingParams(),
+                          batch_update_builder,
+                          vllm_lp,
+                          logits,
+                          batch_output_tokens[2],
+                          req_idx=2,
+                          num_reqs=3)
+
+    assert torch.allclose(logits, out_logits)
+
+    # Step #3, decoding, keep applying min p
+    logits = generate_logits(3)
+    out_logits = decode(
+        batch_update_builder,
+        lp,
+        logits,
+        batch_output_tokens,
+    )
+
+    vllm_logits = decode(
+        batch_update_builder,
+        vllm_lp,
+        logits,
+        batch_output_tokens,
+    )
+
+    assert torch.allclose(out_logits, vllm_logits)
diff --git a/vllm_spyre/v1/sample/spyre_logits_processor.py b/vllm_spyre/v1/sample/spyre_logits_processor.py
index ec6621f7..0700f7c5 100644
--- a/vllm_spyre/v1/sample/spyre_logits_processor.py
+++ b/vllm_spyre/v1/sample/spyre_logits_processor.py
@@ -1,26 +1,51 @@
 import itertools
-from typing import Optional, Sequence, Union
+from typing import TYPE_CHECKING, Optional, Sequence, Union
 
 import torch
-from vllm.config import VllmConfig
 from vllm.logger import init_logger
+# yapf: disable
 from vllm.v1.sample.logits_processor import (BUILTIN_LOGITS_PROCESSORS,
                                              STR_POOLING_REJECTS_LOGITSPROCS,
-                                             BatchUpdate, LogitsProcessor,
-                                             _load_custom_logitsprocs)
+                                             BatchUpdate,
+                                             LogitBiasLogitsProcessor,
+                                             LogitsProcessor,
+                                             MinPLogitsProcessor,
+                                             MinTokensLogitsProcessor,
+                                             _load_custom_logitsprocs,
+                                             process_dict_updates)
+# yapf: enable
 from vllm.v1.sample.logits_processor.state import LogitsProcessors
 
 logger = init_logger(__name__)
 
+if TYPE_CHECKING:
+    from vllm import SamplingParams
+    from vllm.config import VllmConfig
+else:
+    SamplingParams = None
+    VllmConfig = None
+
+SPYRE_BUILTIN_LOGITS_PROCESSORS = [
+    MinPLogitsProcessor, MinTokensLogitsProcessor, LogitBiasLogitsProcessor
+]
+
 
 def build_logitsprocs_for_cb(
-        vllm_config: "VllmConfig",
+        vllm_config: VllmConfig,
         device: torch.device,
         is_pin_memory: bool,
         is_pooling_model: bool,
         batch_size: int,
         custom_logitsprocs: Sequence[Union[str, type[LogitsProcessor]]] = (),
 ) -> LogitsProcessors:
+
+    if len(BUILTIN_LOGITS_PROCESSORS) > 3:
+        logger.warning(
+            "There are %d logits processors, which is unexpected "
+            "for this vllm-spyre version. Consider upgrading "
Consider upgrade " + "vllm-spyre or open an issue to investigate this", + len(BUILTIN_LOGITS_PROCESSORS)) + if is_pooling_model: if custom_logitsprocs: raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS) @@ -29,21 +54,107 @@ def build_logitsprocs_for_cb( return LogitsProcessors() custom_logitsprocs_classes = _load_custom_logitsprocs(custom_logitsprocs) - return LogitsProcessors( - LogitProcessorWrapper(logit_processor, + # Collect builtin LPs to fallback to the wrapper + builtin_logitsprocs = [lp for lp in BUILTIN_LOGITS_PROCESSORS \ + if lp not in SPYRE_BUILTIN_LOGITS_PROCESSORS] + + logitprocs_classes = custom_logitsprocs_classes + builtin_logitsprocs + + return LogitsProcessors( itertools.chain( + [SpyreLogitBiasLogitsProcessor(vllm_config, + device, + is_pin_memory), + SpyreMinPLogitsProcessor(vllm_config, + device, + is_pin_memory), + SpyreMinTokensLogitsProcessor(vllm_config, + device, + is_pin_memory), + ], + [LogitsProcessorWrapper(logit_processor, vllm_config, device, is_pin_memory, batch_size) \ - for logit_processor in itertools.chain( - BUILTIN_LOGITS_PROCESSORS, - custom_logitsprocs_classes - ) - ) + for logit_processor in logitprocs_classes] + )) + + +class SpyreLogitsProcessor: + + def set_prefill(self, idx: int) -> None: + raise NotImplementedError + + +class PrefillHelperLogitsProcessor(LogitsProcessor, SpyreLogitsProcessor): + """ + Logits processor (LP) that separates two instances of a concrete LPS: + one for the prefill, and other for the batch. This class only works if + the state of the LP is independent between prefill and decoding. for + example this class is not suitable for the golden token injector LP. + """ + + def __init__(self, config: VllmConfig, device: torch.device, + is_pin_memory: bool, logit_processor: LogitsProcessor): + self._prefill_lp : LogitsProcessor = \ + logit_processor(config, device, is_pin_memory) + self._batch_lp : LogitsProcessor = \ + logit_processor(config, device, is_pin_memory) + + self._is_prefill: bool = False + + # This dictionary stores the sampling parameters of `update_state` so + # we can get when we call `set_prefill` to proper setup the prefill_lp. 
+        self._params: dict[int, tuple[SamplingParams, list[int],
+                                      list[int]]] = {}
+
+    def is_argmax_invariant(self) -> bool:
+        """Never impacts greedy sampling"""
+        return self._batch_lp.is_argmax_invariant()
+
+    @staticmethod
+    def update_batch_params(
+        params: SamplingParams, prompt_tok_ids: list[int] | None,
+        output_tok_ids: list[int]
+    ) -> tuple[SamplingParams, Sequence[int] | None, Sequence[int]] | None:
+        return params, prompt_tok_ids, output_tok_ids
 
+    def update_state(self, batch_update: BatchUpdate | None):
 
-class LogitProcessorWrapper(LogitsProcessor):
-    """Logit processor to inject expected token during generation for tests"""
+        process_dict_updates(self._params, batch_update,
+                             self.update_batch_params)
+
+        # Always pass the update through to the batch LP
+        self._batch_lp.update_state(batch_update)
+
+    def set_prefill(self, idx: int) -> None:
+
+        params, prompt_tok_ids, out_tok_ids = self._params[idx]
+        self._prefill_lp.update_state(
+            BatchUpdate(
+                batch_size=1,
+                removed=[],
+                moved=[],
+                added=[(0, params, prompt_tok_ids, out_tok_ids)],
+            ))
+        # self._params.pop(idx)
+        self._is_prefill = True
+
+    def apply(self, logits: torch.Tensor) -> torch.Tensor:
+        if self._is_prefill:
+            logits = self._prefill_lp.apply(logits)
+
+            # Clean up the prefill LP
+            self._is_prefill = False
+            self._prefill_lp.update_state(
+                BatchUpdate(batch_size=1, removed=[0], moved=[], added=[]))
+            return logits
+
+        return self._batch_lp.apply(logits)
+
+
+class LogitsProcessorWrapper(LogitsProcessor, SpyreLogitsProcessor):
+    """Logits processor wrapper that runs each logits processor in isolation"""
 
     def __init__(self, logit_processor: LogitsProcessor,
                  vllm_config: VllmConfig, device: torch.device,
@@ -99,5 +210,29 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
 
         return logits
 
-    def set_prefill_index(self, idx: int) -> None:
+    def set_prefill(self, idx: int) -> None:
         self._prefill_index = idx
+
+
+class SpyreMinPLogitsProcessor(PrefillHelperLogitsProcessor):
+
+    def __init__(self, vllm_config: VllmConfig, device: torch.device,
+                 is_pin_memory: bool):
+        super().__init__(vllm_config, device, is_pin_memory,
+                         MinPLogitsProcessor)
+
+
+class SpyreLogitBiasLogitsProcessor(PrefillHelperLogitsProcessor):
+
+    def __init__(self, vllm_config: VllmConfig, device: torch.device,
+                 is_pin_memory: bool):
+        super().__init__(vllm_config, device, is_pin_memory,
+                         LogitBiasLogitsProcessor)
+
+
+class SpyreMinTokensLogitsProcessor(PrefillHelperLogitsProcessor):
+
+    def __init__(self, vllm_config: VllmConfig, device: torch.device,
+                 is_pin_memory: bool):
+        super().__init__(vllm_config, device, is_pin_memory,
+                         MinTokensLogitsProcessor)
diff --git a/vllm_spyre/v1/worker/spyre_input_batch.py b/vllm_spyre/v1/worker/spyre_input_batch.py
index fbdd0d15..89a1f3c6 100644
--- a/vllm_spyre/v1/worker/spyre_input_batch.py
+++ b/vllm_spyre/v1/worker/spyre_input_batch.py
@@ -18,7 +18,7 @@
                                           MoveDirectionality)
 from vllm.v1.sample.metadata import SamplingMetadata
 
-from vllm_spyre.v1.sample.spyre_logits_processor import LogitProcessorWrapper
+from vllm_spyre.v1.sample.spyre_logits_processor import SpyreLogitsProcessor
 
 
 @dataclass
@@ -301,8 +301,8 @@ def __init__(self,
         self.batch_update_builder = BatchUpdateBuilder()
 
         self.logitsprocs = logitsprocs or LogitsProcessors()
-        self.logitsprocs_wrappers = [lp for lp \
-            in self.logitsprocs.all if isinstance(lp, LogitProcessorWrapper)
+        self.spyre_logitsprocs = [lp for lp \
+            in self.logitsprocs.all if isinstance(lp, SpyreLogitsProcessor)
         ]
 
         self.has_allowed_token_ids: set[str] = set()
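To make the intended call sequence of `PrefillHelperLogitsProcessor` concrete, here is a rough sketch of how the runner is expected to drive one of the Spyre LPs around a prefill step. The dummy config and the hand-seeded `_all_stop_token_ids` mirror the unit tests above; the tensors and the `min_tokens` value are illustrative assumptions, not part of the PR:

```python
# Illustrative sketch of the update_state -> set_prefill -> apply lifecycle.
import torch
from vllm import SamplingParams
from vllm.v1.sample.logits_processor import BatchUpdate

from vllm_spyre.v1.sample.spyre_logits_processor import (
    SpyreMinTokensLogitsProcessor)


class _DummySchedulerConfig:
    max_num_seqs = 4


class _DummyVllmConfig:
    scheduler_config = _DummySchedulerConfig()


lp = SpyreMinTokensLogitsProcessor(_DummyVllmConfig(), torch.device("cpu"),
                                   False)

params = SamplingParams(min_tokens=4)
# The engine normally derives this; seeded by hand here, as in the unit tests.
params._all_stop_token_ids = {3}

# 1) The new request enters the persistent batch: the batch LP and the
#    helper's internal params dict are both updated.
lp.update_state(
    BatchUpdate(batch_size=1, removed=[], moved=[],
                added=[(0, params, [0] * 8, [])]))

# 2) The runner marks which request is being prefilled; the prefill LP is
#    seeded with that request's params at index 0.
lp.set_prefill(0)

# 3) The next apply() runs on the single-row prefill logits and then resets
#    the prefill LP; later apply() calls go through the batch LP instead.
prefill_logits = lp.apply(torch.zeros((1, 8)))  # EOS column pushed to -inf
decode_logits = lp.apply(torch.zeros((1, 8)))   # handled by the batch LP
```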
diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py
index 40e57b43..72edf4b3 100644
--- a/vllm_spyre/v1/worker/spyre_model_runner.py
+++ b/vllm_spyre/v1/worker/spyre_model_runner.py
@@ -404,6 +404,9 @@ def update_states(self, scheduler_output: SchedulerOutput):
             # of logitprocs. Refactor so that we can batch removals to the
             # `input_batch`
             self.input_batch.refresh_metadata()
+        else:
+            # The logits processors need the metadata refreshed at each step
+            self.input_batch.refresh_metadata()
 
     def _get_prompt_logprobs_dict(
         self,
@@ -1011,14 +1014,14 @@ def _prepare_prompt(
         prefill_index = self.input_batch.add_request(req_state)
         self.prefill_batch.add_request(req_state)
 
-        # set prefill index for logits processor
-        for logitsproc in self.input_batch.logitsprocs_wrappers:
-            logitsproc.set_prefill_index(prefill_index)
-
         # Refresh sampling metadata after all request are added to the batch
         self.input_batch.refresh_metadata()
         self.prefill_batch.refresh_metadata()
 
+        # Set the prefill index for the Spyre logits processors
+        for logitsproc in self.input_batch.spyre_logitsprocs:
+            logitsproc.set_prefill(prefill_index)
+
         self.model.indices = torch.ones(1, dtype=torch.bool, device='cpu')
         slot_mapping = torch.tensor([slots], dtype=torch.int64)
         prompt_token_ids_tensor = torch.tensor(prompt_token_ids,
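Taken together, these changes let the Spyre backend honor per-request sampling controls (logit bias, min tokens, min-p) within a single batch. A rough end-to-end sketch of the usage pattern the updated e2e tests exercise; the model path, prompts, and parameter values are illustrative placeholders:

```python
# Illustrative sketch only; model path and prompts are placeholders.
from vllm import LLM, SamplingParams

llm = LLM(model="/path/to/model", max_model_len=128)

prompts = ["What is the capital of the USA?", "The opposite of black is"]
params = [
    SamplingParams(min_tokens=19, max_tokens=20),
    SamplingParams(min_p=0.5, temperature=1, max_tokens=5),
]

# One batched call with per-prompt sampling params, as in the updated tests.
outputs = llm.generate(prompts, params, use_tqdm=False)
for out in outputs:
    print(out.outputs[0].text)
```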