
Flash Transformers modeling backend support #2913

Merged: 33 commits merged into main from transformers-backend on Jan 21, 2025
Conversation

@Cyrilvallez (Member) commented Jan 15, 2025

What does this PR do?

Supersedes #2858 for convenience (CIs), as requested (I cannot just change the origin branch of the old PR).

```diff
@@ -219,7 +218,7 @@ def filter(self, indices):
         return None


-class HeterogeneousTopPLogitsWarper(LogitsWarper):
+class HeterogeneousTopPLogitsWarper(LogitsProcessor):
```
Collaborator:
Why these changes?

It looks like these are Warpers, not Processors... no?

Member Author:

Recent Transformers versions deprecated LogitsWarper in favor of LogitsProcessor; they have the same logic/interface.

Member Author:

LogitsWarper has now been completely removed, so it needs to change here.
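
For readers outside the thread, a minimal sketch of what the swap amounts to, assuming the usual Hugging Face interface: LogitsProcessor exposes the same `__call__(input_ids, scores)` contract that LogitsWarper had, so only the base class changes. The body below is illustrative, not TGI's exact implementation:

```python
# Illustrative sketch only (not the exact TGI code): the base class moves
# from the removed transformers.LogitsWarper to transformers.LogitsProcessor,
# whose __call__(input_ids, scores) interface is identical.
import torch
from transformers import LogitsProcessor


class HeterogeneousTopPLogitsWarper(LogitsProcessor):
    """Top-p (nucleus) filtering with a potentially different p per batch row."""

    def __init__(self, top_p: list, dtype: torch.dtype, device: torch.device):
        # One top-p value per sequence in the batch, shaped [batch, 1]
        self.top_p_opposite = 1 - torch.tensor(
            top_p, dtype=dtype, device=device
        ).unsqueeze(1)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        # Sort ascending so the low-probability tail comes first
        sorted_logits, sorted_indices = torch.sort(scores, descending=False)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        # Tokens whose cumulative mass stays below (1 - p) fall outside the nucleus
        sorted_removed = cumulative_probs <= self.top_p_opposite
        # Map the removal mask back to the original (unsorted) vocabulary order
        removed = sorted_removed.scatter(1, sorted_indices, sorted_removed)
        return scores.masked_fill(removed, -float("inf"))
```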

```python
        return logits


    def forward(
```
Collaborator:

To save rewriting all of this, can we maybe monkey-patch model.forward?

It's not great, but it still feels better than this (which duplicates quite complex logic that we're likely to forget to update).

In the __init__: self.model.forward = monkey_patch_forward (something that does the simple unsqueeze/squeeze).

Wdyt?

The other option would be to introduce an indirection in FlashCausalLM by having an inner method that captures that model.forward call, and we override it (I'm not a fan of adding that level of indirection either; it's just as ugly as the monkey patching in my book, because it's still impossible for the reader to know that it's being modified somewhere else).

Member Author:

Yes, I thought about it but was not sure you guys would like it 🙃 I'll update with the monkey patch.
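
A rough sketch of what such a patch could look like, assuming (hypothetically) that the flash path feeds the model flattened `[total_tokens]` tensors while the stock Transformers forward expects a `[batch, seq]` layout. Names and shapes here are illustrative, not the merged implementation:

```python
# Hypothetical sketch of the monkey-patch discussed above: wrap the original
# model.forward so unbatched inputs gain the batch dimension Transformers
# expects, then drop it again on the way out ("the simple unsqueeze, squeeze").
import functools


def monkey_patch_forward(model):
    original_forward = model.forward

    @functools.wraps(original_forward)
    def patched_forward(input_ids, position_ids, **kwargs):
        logits = original_forward(
            input_ids=input_ids.unsqueeze(0),        # [seq] -> [1, seq]
            position_ids=position_ids.unsqueeze(0),  # [seq] -> [1, seq]
            **kwargs,
        ).logits
        return logits.squeeze(0)                     # [1, seq, vocab] -> [seq, vocab]

    model.forward = patched_forward


# Per the suggestion above, this would be called once in the backend's __init__:
#     monkey_patch_forward(self.model)
```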

@Cyrilvallez changed the title from "Transformers backend" to "Flash Transformers modeling backend support" on Jan 20, 2025
@Narsil (Collaborator) previously approved these changes on Jan 20, 2025:

LGTM. Thanks for this!

@Narsil merged commit b980848 into main on Jan 21, 2025 (13 of 14 checks passed).
@Narsil deleted the transformers-backend branch on January 21, 2025 at 09:01.