
Flash attention #1651

Merged: 3 commits merged on Apr 9, 2024
Conversation

@minhthuc2502 (Collaborator) commented Mar 28, 2024

I am currently working on the integration of flash attention. It is based on the kernel developed in the original repo.

For the current version, I don't see any improvement in performance compared with the standard MHA. Tested on an A100 GPU.

Something to consider here:

  • Implement flash attention with the KV cache fused into the kernel.
  • Accept a small performance loss from the additional transpose of QKV needed to align its shape between the standard and flash attention paths (see the layout sketch after this list).
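
A minimal sketch of the transpose mentioned in the second point, assuming the standard MHA path keeps QKV in [batch, heads, time, head_dim] while the flash-attention kernel expects [batch, time, heads, head_dim]; the exact layouts and names are illustrative assumptions, not taken from the CTranslate2 sources:

```python
import numpy as np

# Illustrative shapes only, not the actual CTranslate2 tensors.
batch, heads, time, head_dim = 2, 8, 128, 64

# Standard MHA path: QKV laid out as [batch, heads, time, head_dim].
q_standard = np.random.rand(batch, heads, time, head_dim).astype(np.float16)

# The flash-attention kernel (as in the original repo) is assumed to take
# [batch, time, heads, head_dim], hence the extra transpose on this path.
q_flash = np.ascontiguousarray(q_standard.transpose(0, 2, 1, 3))

print(q_standard.shape)  # (2, 8, 128, 64)
print(q_flash.shape)     # (2, 128, 8, 64)
```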

Update:

  • With long prompts, there is a good improvement in performance (a rough benchmark sketch follows this list).
  • The package size increases after integrating this feature.
  • Compilation time increases due to the heavy use of templates (but this improves runtime performance).
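
A rough way to reproduce the long-prompt comparison from the Python side, assuming a wheel built from this branch, a converted decoder-only model, and that flash attention is toggled through a `flash_attention` constructor flag; the flag name, model path, and dummy prompt tokens are assumptions, not confirmed by this PR:

```python
import time
import ctranslate2

# Hypothetical model path; any converted decoder-only model would do.
model_path = "llama-2-7b-ct2"

# A long dummy prompt to exercise the prefill, where the gain was observed.
prompt = ["<s>"] + ["▁token"] * 2000

for use_flash in (False, True):
    # The flash_attention flag is assumed here; check the release notes of
    # the version you build for the exact option name.
    generator = ctranslate2.Generator(
        model_path, device="cuda", flash_attention=use_flash
    )
    start = time.time()
    generator.generate_batch([prompt], max_length=32)
    print(f"flash_attention={use_flash}: {time.time() - start:.2f}s")
    del generator  # free the GPU model before the next configuration
```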

@minhthuc2502 added and then removed the "help wanted (Extra attention is needed)" label on Mar 28, 2024
@BBC-Esq commented Apr 2, 2024

Super excited about this. Let me know if you need someone with Windows and an RTX 4090 to test.

@Purfview commented Apr 4, 2024

For the current version, I don't see any improvement in performance

For performance you probably should switch to cuDNN 9.x

From the cuDNN 9.0.0 release notes:

FP16 and BF16 fused flash attention engine performance has been significantly improved for NVIDIA GPUs:
  • Speed-up of up to 50% over cuDNN 8.9.7 on Hopper GPUs.
  • Speed-up of up to 100% over cuDNN 8.9.7 on Ampere GPUs.

@minhthuc2502 (Collaborator, Author) commented Apr 4, 2024

For performance you probably should switch to cuDNN 9.x

I think cuDNN is currently used only for the conv ops, so it wouldn't affect the performance of flash attention. By the way, I do see an improvement when trying with longer inputs. I will do some more tests and improve the compilation time.

@minhthuc2502 force-pushed the dev/flash_attention branch from 0314a82 to 845a327 on April 8, 2024 at 09:31
@minhthuc2502 force-pushed the dev/flash_attention branch from d22ca71 to e6e8f95 on April 8, 2024 at 16:03
@minhthuc2502 changed the title from "[WIP] flash attention" to "Flash attention" on Apr 9, 2024
@minhthuc2502 merged commit 7d63eea into OpenNMT:master on Apr 9, 2024
17 checks passed