Numerical discrepancy in reduced precision operations causes WER degradation for custom models #1845
Comments
I am also fighting accuracy problems after conversion and have been doing some experiments as well, but I get exactly the same numbers for the convolutions (I modified my CTranslate2 Whisper implementation too and am adding back specific operations). I believe you should explicitly specify the parameters for the convolutions in your torch experiments:
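A sketch of what that would look like, assuming large-v3 sizes (128 mel bins, 1280 channels) and the encoder definition from the OpenAI Whisper repo (kernel size 3, padding 1, stride 2 only on the second conv):

```python
import torch

# Whisper encoder convolutions, large-v3 sizes assumed (128 mel bins, 1280 channels).
# kernel_size=3 and padding=1 on both convs; only conv2 is strided.
conv1 = torch.nn.Conv1d(128, 1280, kernel_size=3, stride=1, padding=1)
conv2 = torch.nn.Conv1d(1280, 1280, kernel_size=3, stride=2, padding=1)
```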
This is just how BLAS libs are: they take shortcuts and treat fp ops as associative. I ran your torch equivalent on my CPU and actually got the same 5 numbers as your ct2 result. I compared cuda/cpu with torch for fp32/fp16/bf16 and they were all different (fp32 had RMSD ~1e-7). I've been finetuning my own whisper models and never had problems with conversion (that wasn't PEBKAC).
@sssshhhhhh yes, it might just be BLAS shortcuts and I would expect some instabilities, but 1e-4 feels like too large of a diff. I have compared the weights that are loaded and from what I can see they are the same. I convert my finetuned model in pt format with the following steps:
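Roughly, assuming the fine-tuned checkpoint is in the Hugging Face Transformers layout and using CTranslate2's stock converter (paths are placeholders):

```python
import ctranslate2

# Representative conversion sketch (placeholder paths), keeping float16 weights
# so the converted model matches the checkpoint's storage precision.
converter = ctranslate2.converters.TransformersConverter(
    "/path/to/finetuned-whisper",  # fine-tuned HF checkpoint directory
    copy_files=["tokenizer.json", "preprocessor_config.json"],
)
converter.convert("/path/to/whisper-ct2", quantization="float16", force=True)
```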
It still might be PEBKAC; can you see anything that is incorrect with the above steps?
I don't think it's that bad considering bf16 only has 7 bits of mantissa. In the range [0.015625, 0.03125), 1e-4 is the spacing between values. Even within torch with bf16 I get an RMSD of 1e-3. Your steps look right, and I'm not saying your issue is PEBKAC, since there are lots of ways to train which might hit some edge case. But I do think focusing on these numerical deviations is barking up the wrong tree.
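For reference, that spacing claim checks out: bf16 keeps 7 explicit mantissa bits, so in [2**-6, 2**-5) one ulp is 2**-13 ≈ 1.2e-4. A quick check with torch:

```python
import torch

# bf16 eps is 2**-7; the spacing at magnitudes in [2**-6, 2**-5) is eps * 2**-6 = 2**-13.
eps = torch.finfo(torch.bfloat16).eps
print(eps * 2**-6)  # 0.0001220703125, i.e. ~1.2e-4

# Cross-check: two nearby values round to adjacent bf16 numbers one ulp apart.
x = torch.tensor(0.0200, dtype=torch.bfloat16)
y = torch.tensor(0.0201, dtype=torch.bfloat16)
print((y - x).item())  # ~0.000122
```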
Yes, for a single layer this might be fine. But the error grows as we go through the encoder, so the final encoded features are quite off. The start of this investigation was that for some files the correct token would not even be in the top 5 (while being the most activated in the OpenAI implementation), and then I worked backwards and found this to be the earliest diff. I can create an example of this too if you want. But it might just be that our finetuned model is very sensitive to the precision used.
An example where tokens are completely different might help. I tested a random finetune (jlvdoorn/whisper-large-v3-atco2-asr) and had no problems with it either.
I compared the fp16 encoder output of a random sample. RMSD between oai and ct2 was 5e-3 but the decoded output was identical.
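The comparison itself boils down to something like the following, assuming both encoder outputs were dumped to NumPy files (placeholder file names):

```python
import numpy as np

# Placeholder dumps: encoder outputs saved from the OpenAI and CTranslate2 runs.
oai = np.load("encoder_oai_fp16.npy").astype(np.float32)
ct2 = np.load("encoder_ct2_fp16.npy").astype(np.float32)

rmsd = np.sqrt(np.mean((oai - ct2) ** 2))
print(f"RMSD={rmsd:.2e}  max|diff|={np.max(np.abs(oai - ct2)):.2e}")
```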
I first posted this issue on faster-whisper, but they suggested that I should post it here since it is likely an issue with ctranslate2.
Description
When using custom fine-tuned models, faster-whisper's implementation shows significant WER degradation compared to OpenAI's reference implementation (13.5% vs 8.2% WER). Through investigation, I've traced this to numerical differences that start from the very first Conv1D operation in the encoder and are also present for the original large-v3 weights. The weights are stored in float16 in both cases, suggesting some numerical or algorithmic issue.
Here are some benchmark results on a custom dataset.
After comparing the logits of the two implementations and trying to narrow down the root cause of the issue, I managed to locate that the difference starts as early as the first conv1d operation in the encoder.
Steps to reproduce
I modified the whisper encoder operator function to only apply the conv1d operation, i.e.:
Running the following scripts will show the diff between the implementations.
Faster-whisper:
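A minimal sketch of that script, assuming the encoder has been patched as described to return the conv1d output and that the mel features were saved to a NumPy file (placeholder name):

```python
import numpy as np
import ctranslate2

# Load the converted model; with the encoder patched to stop after the first
# conv1d, encode() returns that intermediate tensor instead of the full states.
model = ctranslate2.models.Whisper(
    "/path/to/whisper-ct2", device="cuda", compute_type="float16"
)

mel = np.load("mel.npy").astype(np.float32)  # placeholder dump, e.g. shape (1, 128, 3000)
features = ctranslate2.StorageView.from_array(mel)
output = model.encode(features, to_cpu=True)

print(np.asarray(output)[0, 0, :5])  # first five values of the conv1d output
```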
The PyTorch equivalent is:
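Again a sketch, assuming the same mel dump and the openai-whisper package, applying only the first encoder conv in float16 on the GPU:

```python
import numpy as np
import torch
import whisper  # openai-whisper

model = whisper.load_model("large-v3").half().cuda().eval()
mel = torch.from_numpy(np.load("mel.npy")).half().cuda()  # same placeholder dump

with torch.no_grad():
    out = model.encoder.conv1(mel)  # only the first conv1d, no GELU or later layers

print(out[0, 0, :5].float().cpu().numpy())
```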
If I run these two scripts, I get the following outputs:
faster-whisper:
pytorch:
Environment
python 3.10.10
CUDA Version: 12.3
All running inside a docker container: nvcr.io/nvidia/pytorch:23.10-py3
GPU: 3090
Additional findings
Precision behavior:
Input dependency (bfloat16):