Request to support FlashAttention in cuda attention.cc #1300
Comments
Hi,
Hi @guillaumekln! How hard would it be to implement this? Can you maybe give us some pointers? Most of the popular libraries support v1.0 already.
Hi,
At this time FlashAttention is mostly useful for training or when processing a long prompt. However, during inference most of the time is usually spent in the iterative decoding, where the bottleneck is somewhere else. It seems the author is still working to optimize this inference part, where the input shape is different: Dao-AILab/flash-attention#346 (comment). Also, I can't find an end-to-end benchmark using FlashAttention for inference. Do you know where we can find one? I only see benchmarks for end-to-end training or for the attention module itself.
But I believe it will be able to reduce the VRAM usage further. Can we get some support to run memory-efficient attention?
Hello all. Just thought I'd post a question about Flash Attention 2 here: https://github.com/Dao-AILab/flash-attention Apparently it's making big waves and seems very powerful. Does anyone plan on seeing if it's something that could be included? I reviewed the prior comments and suggest that we change the topic to Flash Attention 2. I know that @guillaumekln is no longer with faster-whisper, but hopefully one of the admins can weigh in on this possibly powerful feature to include in ctranslate2!!
Hi, my thoughts on this: there are some major pros and some cons. Pros:
Cons:
Could we have your thoughts on this, devs?
Eventually this will be included, but it is not the same story to include a pip package (and we did include flash2 in OpenNMT-py) versus linking a C++ package that is moving quite frequently. Of course we don't want to drop the current path for scaled dot-product attention. Bear in mind that native PyTorch is not dead:
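As an illustrative aside on that "native PyTorch" path, here is a minimal sketch (not CTranslate2 code; whether the flash backend is actually used depends on your PyTorch version, GPU, and dtypes):

```python
import torch
import torch.nn.functional as F

# Illustration only: PyTorch's built-in scaled dot-product attention can
# dispatch to a FlashAttention kernel on recent GPUs (sm80 and newer).
q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)

# Restrict dispatch to the flash kernel so the call fails loudly if it is unavailable.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([1, 8, 1024, 64])
```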
Some other repo claims flash attention helps make transcription much faster: https://github.com/Vaibhavs10/insanely-fast-whisper
CTranslate2 will soon support FlashAttention 2 following PR #1651. I will do the release ASAP. I made some tests and saw a performance improvement with long prompts. It runs only on GPU architectures >= sm80, as mentioned in the original repo. It would be great if you guys could test it.
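A quick way to check whether a GPU meets that requirement is sketched below (assumes PyTorch is available; sm80 corresponds to compute capability 8.0, i.e. Ampere):

```python
import torch

# sm80 corresponds to CUDA compute capability 8.0 (Ampere).
major, minor = torch.cuda.get_device_capability(0)
print(f"Compute capability: {major}.{minor}")
if (major, minor) >= (8, 0):
    print("This GPU should be able to use the flash attention path.")
else:
    print("This GPU is older than sm80; flash attention will not be available.")
```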
Thanks, looking forward to testing it with faster-whisper!
This is great! Any chance you could provide some tips as to how to test this on faster-whisper?
Make sure you have an Ampere GPU or newer. You can just set flash_attention=True when loading the model.
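For example, a minimal sketch (the model name, audio file, and device settings below are placeholders):

```python
from faster_whisper import WhisperModel

# "large-v2" and "audio.wav" are placeholders; an Ampere (sm80) or newer GPU is required.
model = WhisperModel(
    "large-v2",
    device="cuda",
    compute_type="float16",
    flash_attention=True,  # enables the FlashAttention 2 path in CTranslate2
)

segments, info = model.transcribe("audio.wav")
for segment in segments:
    print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}")
```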
Hi @minhthuc2502,
Hello, I did not make a benchmark with Faster Whisper, but there are some benchmarks for Flash Attention with some LLM models here.
I haven't benchmarked Whisper in relation to flash attention, but my hypothesis is that it will not make much of a difference for a beam size of 1, but that it "might" if the beam size is increased. However, the benefit will likely not be nearly as great as with stereotypical chat models. I base this conclusion on the following:
Keep in mind that I just haven't had time to test this. In my testing I try to honestly represent people's hard work, but I'm not a programmer by trade and this is a hobby of mine, so take it with a grain of salt. Hope this helps!
Hi @BBC-Esq, @minhthuc2502. The reason I was asking about a faster-whisper FA benchmark is that I do not see any improvement in speed when loading the Whisper model with FA. Here is the code I used to benchmark:

```python
import time

import torch
from faster_whisper import WhisperModel


def benchmark(model):
    times = []
    # Warmup
    for i in range(10):
        segments, _ = model.transcribe(
            "sample_1.wav",
            language="fr",
        )
        segments = list(segments)
    # Benchmark
    for i in range(100):
        segments, _ = model.transcribe(
            "sample_1.wav",
            language="fr",
        )
        past = time.time()
        segments = list(segments)
        torch.cuda.synchronize()
        times.append(time.time() - past)
    times = times[1:]
    print(f"Mean inference time: {sum(times) / len(times)}")
    print(f"\nTIMES: {times}")


if __name__ == '__main__':
    # model = WhisperModel("/home/user/whisper-large-v2-ct2", flash_attention=True)
    model = WhisperModel("/home/user/whisper-large-v2-ct2", flash_attention=False)
    benchmark(model)
```

The results for the above code snippet (after running it twice independently) are:

About the setup:
Is this result expected? If not, what can be done to make it faster?
@AvivSham If you're asking for my opinion on how to speed things up generally, faster-whisper has a pull request for batch processing that's not yet approved. If you don't want to wait for it you can use the
But if you're asking how to make it faster with flash attention, based on the assumption that you might not be using flash attention correctly with faster-whisper... I'm afraid I can't really help. @minhthuc2502 might be able to help, but what I've learned is that those kinds of questions are better posted on the
With that being said, I can confirm that flash attention works for "chat" models, so I'd be surprised if there's some kind of core issue with the ctranslate2 library that prevents it from working just with Whisper models...
BTW, when I said "I can't really help" it's not that I don't want to... it's just that I'm tapped out as far as my personal knowledge goes... Programming is only a hobby for me, after all. ;-)
@AvivSham You might also test your script using beam sizes 1-5 and see if there's a difference. If there's a noticeable difference between using flash attention and not, you could perhaps rule out the possibility that the flash attention parameter isn't being used at all. At the end of this discussion they do confirm that flash attention can be used...
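For instance, something along these lines could extend the benchmark script above (a sketch; the model path and audio file are the same placeholders used earlier):

```python
import time

from faster_whisper import WhisperModel

# Sketch: time decoding across beam sizes, with and without flash attention.
for flash in (False, True):
    model = WhisperModel("/home/user/whisper-large-v2-ct2", flash_attention=flash)
    for beam_size in (1, 2, 3, 4, 5):
        start = time.time()
        segments, _ = model.transcribe("sample_1.wav", language="fr", beam_size=beam_size)
        list(segments)  # iterating the generator is what actually runs the decoding
        print(f"flash_attention={flash}, beam_size={beam_size}: {time.time() - start:.2f}s")
```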
Thank you for your attempt to help! 😄 I will post this question directly in the
For more information, I ran some benchmarks for Faster Whisper with FlashAttention here.
In recent tests, where I posted benchmarks with FA2, I noticed that with longer sequence lengths the difference between FA2 and standard MHA becomes more obvious. In the case of faster-whisper, however, the 30-second audio chunk is converted to an encoder input with shape (1, 80, 3000), see here. That sequence length is quite small to get much benefit from FA2.
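As a quick sanity check on that shape, the numbers follow from Whisper's usual front-end settings (a sketch assuming 16 kHz audio, a 160-sample hop length, and 80 mel bins):

```python
# Whisper's standard log-mel front end: 16 kHz audio, hop length of 160 samples,
# 80 mel bins, audio padded/trimmed to 30-second chunks.
sample_rate = 16_000
hop_length = 160
n_mels = 80
chunk_seconds = 30

n_frames = chunk_seconds * sample_rate // hop_length
print((1, n_mels, n_frames))  # (1, 80, 3000)
```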
FlashAttention can significantly reduce memory usage and speed up attention, even during inference. Is there any plan to support this implementation:
https://github.com/facebookresearch/xformers/tree/main/xformers/csrc/attention