feat: improve spyre logits processors for CB #527
base: main
Conversation
Signed-off-by: Wallas Santos <[email protected]>
```python
    self.input_batch.refresh_metadata()
else:
    # Due to logits processor we need to refresh metadata at each step
    self.input_batch.refresh_metadata()
```
@tjohnson31415 please see this.
👋 Hi! Thank you for contributing to vLLM support on Spyre.
Or this can be done with
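For example, since both branches end with the same call, the refresh could run unconditionally after the conditional logic; a minimal sketch, where the condition name is an assumption about the surrounding code:

```python
if batch_changed:  # hypothetical condition from the surrounding code
    ...  # existing batch-update handling
# Logits processors require metadata to be refreshed at every step
self.input_batch.refresh_metadata()
```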
Now you are good to go 🚀
Signed-off-by: Wallas Santos <[email protected]>
```diff
 def test_spyre_batch1_logit_bias(model: ModelInfo, backend, monkeypatch,
-                                 use_llm_cache, warmup_shapes):
+                                 use_llm_cache, warmup_shapes, max_model_len,
+                                 max_num_seqs, cb: int):
```
Thoughts on swapping these tests to continuous batching only and not testing static batching at all?
Currently this file takes about 10 minutes to run for static batching, and I'm not sure that makes sense given that we're only focusing on improvements to continuous batching.
```diff
-    warmup_shapes=warmup_shapes,
-)
+    warmup_shapes=warmup_shapes if cb == 0 else None,
+    use_cb=cb == 1)
```
While we're in here, I think the `token_diversity` check could be sped up by:
- using the `n` parameter instead of a for loop for batched decodes
- setting the random seed to a fixed value and using `n < 10`

Running 20 separate batches for one test takes quite a long time on GitHub Actions 🐌🐌🐌 (see the sketch below).
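Something like this, a sketch where the model name and concrete values are placeholders:

```python
from vllm import LLM, SamplingParams

llm = LLM(model="some-model")  # placeholder model name

# One batched request with n parallel samples and a fixed seed replaces
# many separate generate() calls in the token_diversity check.
params = SamplingParams(
    n=8,              # n < 10 sampled completions per prompt
    seed=42,          # fixed seed keeps the check deterministic
    temperature=1.0,  # sampling must stay stochastic to measure diversity
    max_tokens=20,
)
outputs = llm.generate(["Hello, my name is"], params)
# Compare the n completions of the single request for token diversity.
completions = [c.text for c in outputs[0].outputs]
```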
Signed-off-by: Wallas Santos <[email protected]>
bot:test
```python
# Convert logits to probability distribution
probability_values = torch.nn.functional.softmax(logits, dim=-1)
# Calculate maximum probabilities per sequence
max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
# Adjust min_p
adjusted_min_p = max_probabilities.mul_(
    self.min_p[self._prefill_index].unsqueeze(0))
# Identify valid tokens using threshold comparison
invalid_token_mask = probability_values < adjusted_min_p
# Apply mask using boolean indexing
logits[invalid_token_mask] = -float('inf')
self._prefill_index = None

return logits
```
This code is identical to the superclass `apply()`. Isn't it better to just call `super().apply()`?
I'd like to, but `self.min_p` contains data for other requests, which I filter out with `self._prefill_index` to get only the request being prefilled.
Ok, I created a new class `PrefillHelperLogitsProcessor`: it instantiates two logits processors, one for prefill and one for decoding, and our builtin logits processors just reuse the existing implementations in vLLM. The class is more efficient than the `LogitsProcessorWrapper`, but it only works if the state between prefill and decode is independent; it won't work for golden token injection, for example. So I think I solved the code deduplication in vllm-spyre.
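A minimal sketch of that shape — the constructor arguments and delegation details are assumptions based on this description, not the PR's exact code:

```python
class PrefillHelperLogitsProcessor:
    """Sketch: hold two independent logits processors, one applied to the
    single sequence being prefilled and one applied to the decode batch."""

    def __init__(self, prefill_proc, decode_proc):
        self._prefill_proc = prefill_proc  # e.g. a vLLM builtin processor
        self._decode_proc = decode_proc
        self._prefill_index: int | None = None

    def set_prefill(self, idx: int) -> None:
        # The runner marks which batch row is the in-flight prefill.
        self._prefill_index = idx

    def apply(self, logits):
        if self._prefill_index is not None:
            # Prefill step: only the prefilled request's logits arrive here.
            out = self._prefill_proc.apply(logits)
            self._prefill_index = None
            return out
        # Decode step: full-batch logits go to the decode-side processor.
        return self._decode_proc.apply(logits)
```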
> but it only works if the state between prefill and decode are independent

I think a well-behaved LogitsProcessor doesn't need persistent state between prefill and decode, i.e. it can be created from a request after output tokens have been generated. This would be needed to support resuming generation after preemption.

Could GTI be updated to remove the persistent state, i.e. to get its state from the content of the batch_update? `current_token_idx` could be set based on the length of the current output tokens, and `has_error` set if the expected and output tokens don't match, or something like that.
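A rough sketch of that idea — the method and attribute names here are hypothetical, not from the PR:

```python
# Hypothetical reconstruction of GoldenTokenInjector state from the batch
# contents, so that nothing persists between prefill and decode.
def rebuild_state(self, output_tokens: list[int],
                  expected_tokens: list[int]) -> None:
    # The next golden token to inject is simply the next position.
    self.current_token_idx = len(output_tokens)
    # Flag divergence between what was generated and what was expected.
    self.has_error = output_tokens != expected_tokens[:len(output_tokens)]
```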
Signed-off-by: Wallas Santos <[email protected]>
bot:test
```python
class SpyreLogitsProcessor:

    def set_prefill(self, idx: int) -> None:
        raise NotImplementedError
```
Can the GoldenTokenInjector implement SpyreLogitsProcessor and move the state from prefill to decode in the `set_prefill` method?
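If so, it might look roughly like this — a sketch in which the internal attributes are assumptions, not the PR's code:

```python
class GoldenTokenInjector(SpyreLogitsProcessor):
    """Sketch: promote per-request state staged at prefill time into the
    decode-side state when the runner signals the prefill index."""

    def __init__(self) -> None:
        self._staged_state: dict[int, object] = {}  # hypothetical
        self._decode_state: dict[int, object] = {}  # hypothetical

    def set_prefill(self, idx: int) -> None:
        # Move the staged prefill state into decode state for this row.
        if idx in self._staged_state:
            self._decode_state[idx] = self._staged_state.pop(idx)
```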
```diff
-self.logitsprocs_wrappers = [lp for lp \
-    in self.logitsprocs.all if isinstance(lp, LogitProcessorWrapper)]
+self.spyre_logitsprocs = [lp for lp \
+    in self.logitsprocs.all if isinstance(lp, SpyreLogitsProcessor)]
```
Do we need to require that all LogitsProcessors be SpyreLogitsProcessors?
Signed-off-by: Travis Johnson <[email protected]>
@maxdebayser found that this PR has a bug that can be reproduced using the approach in #508
Description
This PR adds builtin logits processors for Spyre. These logits processors are the same as vLLM's, but optimized so that they don't need to be wrapped by the `LogitsProcessorWrapper`, which slices logits at each engine step. They work by calling `set_prefill_index` and properly handling prefill in our Spyre model runner. This PR also extends the sampling-params tests to run with continuous batching for the parameters that use logits processors under the hood.
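For context, the runner-side hookup might look roughly like this — a sketch where the call site is an assumption, and the method name follows the `SpyreLogitsProcessor` interface shown above rather than the PR's exact code:

```python
# Sketch: before a prefill step, the Spyre model runner tells each
# Spyre-aware logits processor which batch row is being prefilled, so it
# can apply per-request parameters without a slicing wrapper.
def _prepare_prefill(self, prefill_batch_index: int) -> None:
    for lp in self.spyre_logitsprocs:
        lp.set_prefill(prefill_batch_index)
```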