Flash Transformers modeling backend support #2913
@@ -219,7 +218,7 @@ def filter(self, indices):
         return None


-class HeterogeneousTopPLogitsWarper(LogitsWarper):
+class HeterogeneousTopPLogitsWarper(LogitsProcessor):
Why these changes?
These look like Warpers, not Processors, no?
Recent Transformers versions deprecated `LogitsWarper` in favor of `LogitsProcessor`. They have the same logic/interface.
`LogitsWarper` has been completely removed, so it needs to change here.
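To illustrate why swapping the base class is behavior-preserving, here is a minimal, dependency-free sketch. It assumes the interface is a single `__call__(input_ids, scores)` method that returns the adjusted scores. The `LogitsProcessor` class below is a local stand-in rather than the real Transformers class, and `ToyFloorProcessor` is a hypothetical filter that operates on plain Python lists instead of tensors.

```python
NEG_INF = float("-inf")


class LogitsProcessor:
    """Local stand-in mirroring the assumed interface of transformers.LogitsProcessor."""

    def __call__(self, input_ids, scores):
        raise NotImplementedError


class ToyFloorProcessor(LogitsProcessor):
    """Hypothetical warper-style filter: masks every score below `floor`."""

    def __init__(self, floor):
        self.floor = floor

    def __call__(self, input_ids, scores):
        # Only __call__ matters to the caller, so the name of the base class
        # (warper or processor) does not affect the filtering logic.
        return [s if s >= self.floor else NEG_INF for s in scores]


processor = ToyFloorProcessor(floor=0.0)
filtered = processor(input_ids=[101, 7], scores=[1.5, -0.2, 0.0])
```

Because the subclass only overrides `__call__`, changing the parent from a warper base class to a processor base class leaves the body untouched, which matches the one-line diff above.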
server/text_generation_server/models/transformers_flash_causal_lm.py
        return logits

    def forward(
To save rewriting all of this, can we maybe monkey patch `model.forward`?
It's not great, but it still feels better than this (which duplicates quite complex logic that we're likely to forget to update).
In the `__init__`: `self.model.forward = monkey_patch_forward` (which does the simple unsqueeze, squeeze).
Wdyt?
The other option would be to introduce an indirection in `FlashCausalLM` by having an inner method that captures that `model.forward` call, which we then override. I'm not a fan of adding that level of indirection either; it's just as ugly as the monkey patching in my book, because it's still impossible for the reader to know that it's being modified somewhere else.
Yes I thought about it but was not sure you guys would like it 🙃 I'll update with the monkey patch
LGTM. Thanks for this !
What does this PR do?
Supersedes #2858 for convenience (CIs) as requested (I cannot just change the origin branch from the old PR)