-
Notifications
You must be signed in to change notification settings - Fork 17
Add ListMLE Loss #130
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add ListMLE Loss #130
Conversation
|
Not sure why tests did not run, commenting so that tests run |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the delay, did one pass! Some questions below, let me know what you think.
I'll take a closer look again!
keras_rs/api/layers/__init__.py
Outdated
| from keras_rs.src.layers.embedding.distributed_embedding import DistributedEmbedding as DistributedEmbedding | ||
| from keras_rs.src.layers.embedding.distributed_embedding_config import FeatureConfig as FeatureConfig | ||
| from keras_rs.src.layers.embedding.distributed_embedding_config import TableConfig as TableConfig | ||
| from keras_rs.src.layers.embedding.embed_reduce import EmbedReduce as EmbedReduce | ||
| from keras_rs.src.layers.feature_interaction.dot_interaction import DotInteraction as DotInteraction | ||
| from keras_rs.src.layers.feature_interaction.feature_cross import FeatureCross as FeatureCross | ||
| from keras_rs.src.layers.retrieval.brute_force_retrieval import BruteForceRetrieval as BruteForceRetrieval | ||
| from keras_rs.src.layers.retrieval.hard_negative_mining import HardNegativeMining as HardNegativeMining | ||
| from keras_rs.src.layers.retrieval.remove_accidental_hits import RemoveAccidentalHits as RemoveAccidentalHits |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmmm, these shouldn't be getting re-formatted, in my opinion. Are you on the correct Ruff version? Same for other files
ef5ee24 to
63388e9
Compare
|
Hi @abheesht17, Done suggested changes. Could you please review the PR. Thank you |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This LGTM! I'm going to drop an approval on this, but could you please format the code first?
|
Not sure why we don't have a way to run tests here, need to figure that out before merging |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Noticed a few nits. Instead of using +, -, *, /, let's use Keras ops everywhere
|
@LakshmiKalaKadali - can you look into the test failures above? |
|
Hi @abheesht17, The torch test fails are due to precision mismatch in between actual and expected values. That precision mismatch might be due to the mask AFTER exp(). Applied the mask to set invalid positions to (-1e9) before calling ops.exp(), and removed the masking that happens after exp(). I hope that this resolves the torch test failure. Could you please run the tests. Thank You |
|
The version of torch was updated for the tests. Please rebase to see if it helps. |
e986ef5 to
3da3fa2
Compare
|
I do not understand why the results are different. But the differences you see are significant enough that it has to be a bug. Can you try to debug by printing all intermediate values in |
|
Sure @hertschuh, I will deep dive into the intermediate values. Thank You |
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
157f774 to
8cb2d98
Compare
Added ListMLELoss code to listwise ranking. This code does not consider Lambda weights.
Here is the gist for verified results with TFRS ListMLELoss