
Update the local attention mask logic to work on MPS and CUDA in ModernBERT #561


Closed
wants to merge 5 commits

Conversation

kozistr
Contributor

@kozistr kozistr commented Apr 5, 2025

What does this PR do?

In the previous PR #459, local_attention could only run on CPU because of the abs() operation.

So, I've changed the logic to compute the window mask in pure Rust and then build the Tensor from it. This should let ModernBERT run on both MPS and CUDA as well.
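
For reference, here's a minimal sketch of the approach (the helper name get_window_mask, the shapes, and the dtype are illustrative, not copied from this PR; it only assumes candle's Tensor::from_vec):

```rust
use candle::{Device, Result, Tensor};

// Build the sliding-window mask entirely on the host in plain Rust, then hand
// the finished buffer to candle. No abs()/cmp kernels are needed on MPS/CUDA;
// the device only receives the ready-made tensor.
fn get_window_mask(seq_len: usize, window_size: usize, device: &Device) -> Result<Tensor> {
    let half = (window_size / 2) as i64;
    let mut data = vec![0f32; seq_len * seq_len];
    for i in 0..seq_len {
        // 1.0 where |i - j| <= window_size / 2, 0.0 elsewhere.
        let lo = (i as i64 - half).max(0) as usize;
        let hi = ((i as i64 + half) as usize).min(seq_len - 1);
        for j in lo..=hi {
            data[i * seq_len + j] = 1.0;
        }
    }
    Tensor::from_vec(data, (seq_len, seq_len), device)
}
```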

I've checked the output with this script and it seems identical to before.

tested devices

  • tested on CPU (local machine, WSL2)
  • tested on T4 GPU (Kaggle notebook, TEI server runs successfully)
  • MPS (not tested yet; expected to work)

performance (get_window_mask())

window size | seq len | latency (p50)
----------- | ------- | -------------
64          | 8192    | 20 ms
64          | 4096    | 6 ms
64          | 1024    | 160 µs

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@Narsil, @alvarobartt, @ivarflakstad

@kozistr kozistr changed the title Update the logic for local attetion mask in ModernBERT Update the local attention mask logic to work on MPS and CUDA in ModernBERT Apr 8, 2025
@Narsil
Collaborator

Narsil commented Apr 8, 2025

@kozistr I think everything already works on MPS; it's been checked by @ErikKaum in #562

During the flash fix I forced the mask creation onto CPU in candle (to enable non-flash CUDA execution), and that fixed Metal as well.
I think keeping everything in candle instead of pure Rust makes it slightly nicer, because we could switch to on-device execution simply by creating the missing kernels (which is relatively easy to do).

In any case, having the mask created only in the first part of the model is acceptable IMHO.
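
A rough sketch of what that candle-only variant could look like (function and variable names are illustrative, not taken from the repo; it assumes candle's arange/broadcast_sub/abs/le/to_device ops):

```rust
use candle::{Device, Result, Tensor};

// Compute the sliding-window mask with candle ops on the CPU device (where the
// required kernels exist), then move the finished tensor to the model device.
fn window_mask_on_cpu(seq_len: usize, window_size: usize, target: &Device) -> Result<Tensor> {
    let cpu = Device::Cpu;
    let positions = Tensor::arange(0i64, seq_len as i64, &cpu)?;
    // |i - j| for every pair of positions, computed entirely on CPU.
    let distance = positions
        .reshape((seq_len, 1))?
        .broadcast_sub(&positions.reshape((1, seq_len))?)?
        .abs()?;
    let limit = Tensor::full((window_size / 2) as i64, (seq_len, seq_len), &cpu)?;
    let mask = distance.le(&limit)?; // 1 inside the window, 0 outside
    // Only this upload touches CUDA/MPS; adding the missing on-device kernels
    // later would let the whole computation move off the CPU.
    mask.to_device(target)
}
```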

@Narsil
Collaborator

Narsil commented Apr 8, 2025

That being said, we could definitely clean up that code a little bit IMHO.

The HashMap<bool, XX> could be replaced with a simple struct Mask { local: Tensor, global: Tensor }, which should improve readability (and the overhead of a HashMap is not negligible, even though I doubt it's measurable in this particular instance).
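
Roughly, the suggested struct (field names follow the comment above; this is not actual repo code):

```rust
use candle::Tensor;

/// Attention masks for ModernBERT's alternating local/global attention layers.
struct Mask {
    local: Tensor,
    global: Tensor,
}

impl Mask {
    // Pick the right mask for a layer without a HashMap lookup.
    fn get(&self, uses_local_attention: bool) -> &Tensor {
        if uses_local_attention {
            &self.local
        } else {
            &self.global
        }
    }
}
```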

@kozistr
Contributor Author

kozistr commented Apr 8, 2025

@kozistr I think everything already works on MPS; it's been checked by @ErikKaum in #562

During the flash fix I forced the mask creation onto CPU in candle (to enable non-flash CUDA execution), and that fixed Metal as well. I think keeping everything in candle instead of pure Rust makes it slightly nicer, because we could switch to on-device execution simply by creating the missing kernels (which is relatively easy to do).

In any case, having the mask created only in the first part of the model is acceptable IMHO.

I thought there were still some issues on MPS, but I missed that PR. Thanks for checking on that.
And I agree with your point that keeping everything in Candle would be great!

That being said, we could definitely clean up that code a little bit IMHO.

The HashMap<bool, XX> could be replaced with a simple struct Mask { local: Tensor, global: Tensor }, which should improve readability (and the overhead of a HashMap is not negligible, even though I doubt it's measurable in this particular instance).

I second this. We could refactor the mask handling to use a struct instead of a HashMap to improve readability.

Then, I'm gonna close this PR and make another contribution later!

Thanks for taking the time to review this PR :)

@kozistr kozistr closed this Apr 8, 2025
@kozistr kozistr deleted the refactor/local-attention branch April 8, 2025 09:29
@Narsil
Collaborator

Narsil commented Apr 8, 2025

Thanks a lot to you.
