Update the local attention mask logic to work on MPS and CUDA in ModernBERT #561
Conversation
@kozistr I think everything already works on MPS; it's been checked by @ErikKaum in #562. During the flash fix I forced the mask creation onto the CPU in candle (to enable non-flash CUDA execution), and that fixed Metal as well. In any case, having a mask that is only created in the first part of the model is acceptable imho.
That being said, we could definitely clean up that code a little bit imho. The HashMap<bool, XX> could be replaced with a simple …
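Purely as an illustration of that kind of cleanup (the names and field layout below are assumptions, not the repository's actual types), the bool-keyed map could be flattened into a small struct holding one mask per attention kind:

```rust
use candle::Tensor;

// Hypothetical replacement for a `HashMap<bool, Tensor>` keyed on
// "is this a local-attention layer?": one named field per mask.
struct AttentionMasks {
    // Mask used by global-attention layers.
    global: Tensor,
    // Sliding-window mask used by local-attention layers.
    local: Tensor,
}

impl AttentionMasks {
    fn for_layer(&self, uses_local_attention: bool) -> &Tensor {
        if uses_local_attention {
            &self.local
        } else {
            &self.global
        }
    }
}
```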
I thought there were still some issues on MPS, but I missed that PR. Thanks for checking on that.
I second this. We could refactor the mask part with … Then, I'm gonna close this PR and make another contribution later! Thanks for taking the time to review this PR :)
Thanks a lot to you.
What does this PR do?
In the previous PR #459, `local_attention` could only work on CPU because of the `abs()` operation. So, I've made a change that calculates the window mask in pure Rust and then creates the Tensor from it. Maybe this could make `ModernBERT` run on both MPS and CUDA too. I've checked the output with this script and it seems identical to before.
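For illustration, here is a minimal sketch of what computing the window mask in pure Rust and only then building a candle tensor could look like. The signature `get_window_mask(seq_len, window_size, device)`, the `u8` dtype, and the exact windowing rule are assumptions for this sketch, not necessarily what the PR implements.

```rust
use candle::{Device, Result, Tensor};

// Hypothetical helper: build a (seq_len, seq_len) sliding-window mask in plain
// Rust so that no `abs()` op has to run on the device. Position j is visible
// from position i only when |i - j| <= window_size / 2.
fn get_window_mask(seq_len: usize, window_size: usize, device: &Device) -> Result<Tensor> {
    let half = window_size / 2;
    let mut data = Vec::with_capacity(seq_len * seq_len);
    for i in 0..seq_len {
        for j in 0..seq_len {
            data.push(u8::from(i.abs_diff(j) <= half));
        }
    }
    // Materialize the mask once on the CPU, then move it to the target device
    // (CUDA or Metal), where it is only consumed by elementwise ops.
    Tensor::from_vec(data, (seq_len, seq_len), &Device::Cpu)?.to_device(device)
}
```

Building the mask on the host this way keeps the device-side graph free of ops (like `abs()`) that are not implemented on every backend.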
Tested devices:

Performance (`get_window_mask()`):

Before submitting

- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@Narsil, @alvarobartt, @ivarflakstad