
Conversation

wallashss
Collaborator

Description

This PR adds builtin logits processors for Spyre. They are the same processors as in vLLM, but optimized so that they do not need to be wrapped by the LogitsProcessorWrapper, which slices the logits at every engine step. They work through a set_prefill_index call and proper prefill handling in our Spyre model runner. The PR also extends the sampling-params tests to run with continuous batching for the parameters that use logits processors under the hood.
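As a rough illustration of that mechanism (not the actual runner code; the function and argument names here are made up, and the concrete interface appears later in the diff as set_prefill):

def run_prefill_step(spyre_logitsprocs, logits, prefill_row_index):
    # The runner tells each Spyre-aware logits processor which row of the
    # logits batch is being prefilled...
    for lp in spyre_logitsprocs:
        lp.set_prefill(prefill_row_index)
    # ...and each processor then applies only that request's sampling
    # parameters for this step, instead of being wrapped and sliced by
    # LogitsProcessorWrapper at every engine step.
    for lp in spyre_logitsprocs:
        logits = lp.apply(logits)
    return logits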

Signed-off-by: Wallas Santos <[email protected]>
Signed-off-by: Wallas Santos <[email protected]>
Signed-off-by: Wallas Santos <[email protected]>
    self.input_batch.refresh_metadata()
else:
    # Due to logits processor we need to refresh metadata at each step
    self.input_batch.refresh_metadata()
Collaborator Author

@tjohnson31415 please see this.


👋 Hi! Thank you for contributing to vLLM support on Spyre.
Just a reminder: make sure your code passes all the linting checks, otherwise your PR can't be merged. To do so, first install the linting requirements, then run format.sh and commit the changes. This can be done with uv directly:

uv sync --frozen --group lint --active --inexact

Or this can be done with pip:

uv pip compile --group lint > requirements-lint.txt
pip install -r requirements-lint.txt
bash format.sh

Now you are good to go 🚀

Signed-off-by: Wallas Santos <[email protected]>
Signed-off-by: Wallas Santos <[email protected]>
def test_spyre_batch1_logit_bias(model: ModelInfo, backend, monkeypatch,
-                                use_llm_cache, warmup_shapes):
+                                use_llm_cache, warmup_shapes, max_model_len,
+                                max_num_seqs, cb: int):
Collaborator

Thoughts on swapping these tests to continuous batching only and not testing at all for static batching?

Currently this file takes about 10 minutes to run for static batching, and I'm not sure that makes sense given that we're only focusing on improvements to continuous batching.

-    warmup_shapes=warmup_shapes,
-)
+    warmup_shapes=warmup_shapes if cb == 0 else None,
+    use_cb=cb == 1)
Collaborator

while we're in here, I think the token_diversity check could be sped up by

  • using the n parameter instead of a for loop for batched decodes
  • setting the random seed to a fixed value and using n < 10

Running 20 separate batches for one test takes quite a long time on github actions 🐌🐌🐌
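A minimal sketch of that suggestion, assuming the check can be written directly against LLM.generate (the model name, prompt, n, and seed below are illustrative, not what the test actually uses):

from vllm import LLM, SamplingParams

# One batched request with n parallel samples and a fixed seed, instead of
# looping over many separate decodes.
llm = LLM(model="facebook/opt-125m")  # placeholder model
params = SamplingParams(n=8, seed=42, temperature=1.0, max_tokens=20)

outputs = llm.generate(["Hello, my name is"], params)
completions = {tuple(c.token_ids) for c in outputs[0].outputs}

# Token-diversity check: with temperature sampling we still expect more than
# one distinct completion among the n samples.
assert len(completions) > 1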

@wallashss
Collaborator Author

bot:test

Comment on lines 150 to 165
# Convert logits to probability distribution
probability_values = torch.nn.functional.softmax(logits, dim=-1)
# Calculate maximum probabilities per sequence
max_probabilities = torch.amax(probability_values,
                               dim=-1,
                               keepdim=True)
# Adjust min_p
adjusted_min_p = max_probabilities.mul_(
    self.min_p[self._prefill_index].unsqueeze(0))
# Identify valid tokens using threshold comparison
invalid_token_mask = probability_values < adjusted_min_p
# Apply mask using boolean indexing
logits[invalid_token_mask] = -float('inf')
self._prefill_index = None

return logits
Collaborator

This code is identical to the super class apply(). Isn't it better to just call super().apply()?

Collaborator Author

I'd like to, but self.min_p contains data for the other requests; I filter it with self._prefill_index to get only the request being prefilled.

Collaborator Author

Ok, I created a new class, PrefillHelperLogitsProcessor: it instantiates two logits processors, one for prefill and one for decoding, and our builtin logits processors just reuse the existing implementations in vLLM. The class is more efficient than the LogitsProcessorWrapper, but it only works if the state between prefill and decode is independent. It won't work for golden token injection, for example. So I think that solves the code deduplication in vllm-spyre.
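For illustration, a rough sketch of that idea (this is not the actual class from the PR; the constructor and the way the two wrapped processors are built are assumptions):

from typing import Optional

import torch


class PrefillHelperLogitsProcessor:
    """Sketch: wrap two instances of the same vLLM logits processor, one fed
    only the row being prefilled and one fed the decode batch, so the existing
    vLLM implementation is reused without the per-step logits slicing that
    LogitsProcessorWrapper performs."""

    def __init__(self, prefill_proc, decode_proc):
        # Both arguments are ordinary vLLM logits processors (illustrative).
        self.prefill_proc = prefill_proc
        self.decode_proc = decode_proc
        self._prefill_index: Optional[int] = None

    def set_prefill(self, idx: int) -> None:
        # Remember which row of the next logits tensor belongs to the request
        # currently being prefilled.
        self._prefill_index = idx

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if self._prefill_index is not None:
            i = self._prefill_index
            # Route only the prefill row through the prefill-side processor.
            logits[i:i + 1] = self.prefill_proc.apply(logits[i:i + 1])
            self._prefill_index = None
            return logits
        # Decode steps use the decode-side processor on the whole batch.
        return self.decode_proc.apply(logits)

As noted above, this only works when the prefill-side and decode-side state never have to be merged; something like golden token injection, which carries state across the prefill/decode boundary, doesn't fit this pattern.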

Collaborator
@tjohnson31415 Oct 20, 2025

but it only works if the state between prefill and decode is independent

I think a well-behaved LogitsProcessor doesn't need persistent state between prefill and decode, i.e. it can be created from a request after output tokens have been generated. This would be needed to support resuming generation after preemption.

Could GTI be updated to remove the persistent state, i.e. to get its state from the content of the batch_update? current_token_idx could be set based on the length of the current output_tokens, and has_error set if the expected and output tokens don't match, or something like that.
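A toy illustration of that suggestion (the names expected_tokens, output_tokens, current_token_idx, and has_error follow the comment above; this is not the actual GoldenTokenInjector code):

def derive_gti_state(expected_tokens: list[int],
                     output_tokens: list[int]) -> tuple[int, bool]:
    # Position in the golden sequence is just the number of tokens generated
    # so far, and the error flag records whether the generated prefix has
    # diverged from the expected tokens.
    current_token_idx = len(output_tokens)
    has_error = output_tokens != expected_tokens[:current_token_idx]
    return current_token_idx, has_error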

Signed-off-by: Wallas Santos <[email protected]>
Signed-off-by: Wallas Santos <[email protected]>
@wallashss
Collaborator Author

bot:test

1 similar comment
@wallashss
Collaborator Author

bot:test

class SpyreLogitsProcessor:

    def set_prefill(self, idx: int) -> None:
        raise NotImplementedError
Collaborator

Can the GoldenTokenInjector implement SpyreLogitsProcessor and move the state from prefill to decode in the set_prefill method?

-        self.logitsprocs_wrappers = [lp for lp \
-            in self.logitsprocs.all if isinstance(lp, LogitProcessorWrapper)]
+        self.spyre_logitsprocs = [lp for lp \
+            in self.logitsprocs.all if isinstance(lp, SpyreLogitsProcessor)]
Collaborator

Do we need to require that all LogitsProcessors be SpyreLogitsProcessors?


Signed-off-by: Travis Johnson <[email protected]>
Collaborator
@tjohnson31415 left a comment

@maxdebayser found that this PR has a bug that can be reproduced using the approach in #508
