Qualcomm: cap inf replacement value to fix 16a16w accuracy regression#20471
Qualcomm: cap inf replacement value to fix 16a16w accuracy regression#20471psiddh wants to merge 4 commits into
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20471
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ⏳ No Failures, 12 PendingAs of commit 8875528 with merge base da9158b ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
|
This PR needs a
|
PR pytorch#19660 folded ReplaceInfValues into QnnQuantizer._replace_inf and made the inf stand-in equal to the full quant range. For 16a16w that is 65535 (vs the previous fixed 255), which blows up the attention-mask quant scale and breaks stories110M decoding in test-llama-runner-qnn-linux. Cap the magnitude at 255 to restore prior behavior; 8a8w is unaffected.
58aed0e to
2333727
Compare
There was a problem hiding this comment.
Pull request overview
Caps the stand-in value used when replacing ±inf constants during QNN quantization annotation so that higher-bit activation quantization (notably 16a16w) doesn’t inflate attention-mask dynamic range and degrade Llama decoding accuracy.
Changes:
- Cap
_get_quant_range()’s returned value to<= 255to prevent large integer ranges (e.g., uint16) from dominating observer calibration. - Add an explanatory code comment documenting the rationale and the Llama attention-mask motivation.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| # Cap the inf stand-in so it does not dominate the tensor's | ||
| # dynamic range. For >8-bit activations the full range (e.g. | ||
| # 65535 for uint16) would blow up the attention-mask quant scale | ||
| # and wreck accuracy; 255 keeps a reasonable scale for | ||
| # Llama-style attention masks. |
| # 65535 for uint16) would blow up the attention-mask quant scale | ||
| # and wreck accuracy; 255 keeps a reasonable scale for | ||
| # Llama-style attention masks. | ||
| return min(quant_range, 255) |
There was a problem hiding this comment.
Sure, please go ahead and validate on CI, happy to wait for your results.
On the hardcoded 255: agree we should avoid magic numbers. perhaps a cleaner version would be to derive the cap from the int8 range rather than hardcoding: ..
return min(quant_range, torch.iinfo(torch.int8).max - torch.iinfo(torch.int8).min)
There was a problem hiding this comment.
@winskuo-quic Any updates ? We would like to land this asap as this is stalling viable/strict branch from moving forward.
There was a problem hiding this comment.
@winskuo-quic The CI runner for test-llama-runner-qnn-linux (qnn_16a16w) is OOM-killing the export step, exit code 137 (SIGKILL), The job is never getting far enough , the Python export process gets killed before it finishes , it looks like.
There was a problem hiding this comment.
Hi @psiddh,
Thanks for sharing all the test results.
I am also encountering the same issue.
I noticed it is killed during qnn_preprocess. I believe there's a chance OOM happened inside QNN SDK, which is weird as this is a small model.
I am still working on finding a solution to resolve the issue.
If you happened to find a work around, please feel free to merge first to unblock the CI error.
Thanks again for all the help.
There was a problem hiding this comment.
Here is the quick work-around : #20511 , This will give us time to investigate OOM issue
There was a problem hiding this comment.
Thanks for the workaround. I will try to find re-enable 16a16w test.
67f78b6 to
bf8bdbf
Compare
The 16a16w export+compile is OOM-killed on linux.2xlarge (23 min) and runner crashes on linux.4xlarge.memory (93 min). Use linux.8xlarge.memory to verify whether the accuracy fix works once memory is sufficient.
bf8bdbf to
8875528
Compare
| fail-fast: false | ||
| with: | ||
| runner: linux.2xlarge | ||
| runner: linux.8xlarge.memory |
PR #19660 folded ReplaceInfValues into QnnQuantizer._replace_inf and made the inf stand-in equal to the full quant range. For 16a16w that is 65535 (vs the previous fixed 255), which blows up the attention-mask quant scale and breaks stories110M decoding in test-llama-runner-qnn-linux. Cap the magnitude at 255 to restore prior behavior; 8a8w is unaffected.