Multi Label Precision numerical stability issues with MPS #2955
Labels: bug / fix, duplicate, help wanted, v1.6.x
🐛 Bug
I'm currently tracking metrics for a fairly complex multi-label ECG classifier (20 labels). I noticed that training on MPS breaks MultiLabelPrecision in an odd way: the returned values fall far outside the valid range of [0, 1], e.g. -1.49272e+25. At first I thought it was a numerical issue in the logits or an input-shape issue I wasn't handling properly. After further digging, I found that the problem occurs specifically on the MPS device and can be reproduced. I couldn't pin down the cause because it doesn't happen on every run, but it is most likely numerical.
Would love some guidance on this!
To Reproduce
The snippet below includes manually calculated expected outputs. I hope my math isn't wrong. In any case, if I set the device to MPS and run the code a few times (some runs succeed), I eventually get very odd numerical values in tests 3 and 4:
Code sample
Expected behavior
Tests 3 and 4 should return the following on MPS:
=== Exactly 5 labels correct per sample ===
DEBUG:main:Computed precision: 0.25
✓ PASSED - Final Precision: 0.2500
=== 5 correct + 5 false positives ===
DEBUG:main:Computed precision: 0.25
✓ PASSED - Final Precision: 0.2500
Instead, they return something like:
=== Exactly 5 labels correct per sample ===
DEBUG:main:Computed precision: 16985032228864.0
✗ FAILED: Expected 0.25, got 16985032228864.0000
=== 5 correct + 5 false positives ===
DEBUG:main:Computed precision: -1.4927448875054142e+25
✗ FAILED: Expected 0.25, got -14927448875054142008066048.0000
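For reference, an expected value of 0.25 for both tests is consistent with macro averaging over all 20 labels, which is the default `average` for torchmetrics' `MultilabelPrecision`. This assumes that a label with no predicted positives contributes 0. The pure-Python sketch below reproduces that arithmetic without any device involvement. The inputs are hypothetical stand-ins for tests 3 and 4 (4 samples, labels 0-4 positive), not the original snippet, and `macro_precision` / `micro_precision` are illustrative helpers, not library functions. Because each per-label precision is a ratio of non-negative counts, any average of them lies in [0, 1], so values like the MPS results above cannot come from a valid count.

```python
from typing import List

Matrix = List[List[int]]  # rows = samples, columns = labels (0/1)


def macro_precision(preds: Matrix, target: Matrix) -> float:
    """Hypothetical helper: mean of per-label precision; labels with no predicted positives count as 0."""
    num_labels = len(preds[0])
    per_label = []
    for j in range(num_labels):
        tp = sum(1 for p, t in zip(preds, target) if p[j] == 1 and t[j] == 1)
        fp = sum(1 for p, t in zip(preds, target) if p[j] == 1 and t[j] == 0)
        # Zero-division case (no predicted positives for this label) is scored as 0.
        per_label.append(tp / (tp + fp) if tp + fp > 0 else 0.0)
    return sum(per_label) / num_labels


def micro_precision(preds: Matrix, target: Matrix) -> float:
    """Hypothetical helper: pool TP/FP over all samples and labels before dividing."""
    pairs = [(p, t) for ps, ts in zip(preds, target) for p, t in zip(ps, ts)]
    tp = sum(1 for p, t in pairs if p == 1 and t == 1)
    fp = sum(1 for p, t in pairs if p == 1 and t == 0)
    return tp / (tp + fp) if tp + fp > 0 else 0.0


NUM_LABELS = 20
NUM_SAMPLES = 4

# Assumed targets: the first 5 labels are positive for every sample.
target = [[1 if j < 5 else 0 for j in range(NUM_LABELS)] for _ in range(NUM_SAMPLES)]
# Assumed test 3: predictions match the 5 positive labels exactly.
preds_exact = [row[:] for row in target]
# Assumed test 4: the 5 correct labels plus 5 false positives (labels 5-9).
preds_extra_fp = [[1 if j < 10 else 0 for j in range(NUM_LABELS)] for _ in range(NUM_SAMPLES)]

print(macro_precision(preds_exact, target))     # → 0.25
print(macro_precision(preds_extra_fp, target))  # → 0.25
print(micro_precision(preds_exact, target))     # → 1.0
print(micro_precision(preds_extra_fp, target))  # → 0.5
```

Micro averaging would give 1.0 and 0.5 for the same inputs, so the 0.25 expectation only holds under macro averaging. A CPU reference like this can help separate a math mistake in the expected values from a device-specific problem.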
Environment