
Add XTR scoring #86

Open
robro612 opened this issue Feb 3, 2025 · 4 comments
Labels
enhancement New feature or request

Comments

@robro612

robro612 commented Feb 3, 2025

Recently, WARP has reminded us about and improved upon XTR's modified inference scheme, which obviates the need to load every embedding of the retrieved candidate documents by imputing each missing query-doc score as the minimum score seen among the retrieved scores for that query token.
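To make the imputation concrete, here is a minimal numpy sketch (the function name and signature are illustrative, not PyLate code). It assumes each query token has at least one retrieved score, and uses the minimum score within the given retrieved set as the per-query-token floor:

```python
import numpy as np

def xtr_impute_scores(retrieved_scores, mask):
    """Hypothetical sketch of XTR inference-time scoring with imputation.

    retrieved_scores: (num_query_tokens, num_doc_tokens) similarity scores.
    mask: boolean array, True where the doc token was actually retrieved
    for that query token. Assumes every row has at least one True entry.
    """
    # Minimum retrieved score per query token, used as the imputed value
    # for query-doc pairs that were not retrieved.
    floor = np.where(mask, retrieved_scores, np.inf).min(axis=1, keepdims=True)
    imputed = np.where(mask, retrieved_scores, floor)
    # Sum-of-max scoring: max over doc tokens, summed over query tokens.
    return imputed.max(axis=1).sum()
```

The point is that no embeddings for the missing pairs are ever loaded; only the retrieved scores and a per-query-token floor are needed.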

This change, however, requires an analogous modification of the scoring function during training: keeping only the query-document token scores that fall within the top k_train of all in-batch document token scores for a given query token.

This change should be fairly straightforward, mainly requiring one or more xtr_scores functions (depending on the multiplicity of documents per query) in scores/scores.py.
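A hedged sketch of what such a training-side function could look like, again in numpy with illustrative names (`xtr_scores` here is the hypothetical function suggested above, not existing PyLate code). Scores outside each query token's top k_train in-batch scores are masked to zero; ties at the threshold are kept:

```python
import numpy as np

def xtr_scores(token_scores, k_train):
    """Mask token scores outside each query token's in-batch top-k_train.

    token_scores: (num_query_tokens, total_in_batch_doc_tokens).
    Returns the scores with everything strictly below the k_train-th
    largest score per query token zeroed out.
    """
    # k_train-th largest score per query token, via a partial sort.
    kth = np.partition(token_scores, -k_train, axis=1)[:, -k_train][:, None]
    return np.where(token_scores >= kth, token_scores, 0.0)
```

The masked scores would then feed into the usual sum-of-max reduction during training, so that training matches the top-k-truncated view the model sees at inference.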

I'd like to take this on if that's alright with y'all @NohTow @raphaelsty . It could be a good prelude to adding WARP support after the PLAID implementation is up and running.

@raphaelsty
Collaborator

raphaelsty commented Feb 3, 2025

Amazing @robro612, this would be awesome. Feel free to ask any questions or to request an early review if you want! If I understand correctly, we need PLAID up and running before we can have a WARP retriever? 🚀

@robro612
Author

robro612 commented Feb 3, 2025

This issue will allow people to use PyLate to train XTR models, and the naive (non-WARP) retrieval is pretty straightforward.

I haven't done a super deep dive into the WARP code; it reuses many of the same ColBERT structures, but I can't speak to how interchangeable WARP and PLAID would be at the implementation level.

@NohTow added the enhancement label on Feb 4, 2025
@NohTow
Collaborator

NohTow commented Feb 4, 2025

Hello @robro612, thanks for the initiative, appreciate it!

As mentioned in #83, I am interested in adding PLAID and XTR, as well as maybe MUVERA, to PyLate, so this would definitely be appreciated. I think a dedicated loss function would be adequate for training, and we can obviously add a specific scoring method as well if required.

That's for training. For indexation, as mentioned in #83, someone offered to help a bit with a native PLAID/MUVERA implementation that does not rely on the stanford-nlp code. Until that happens (if it does), maybe we can go with what I started building on the PLAID branch, which basically copies and pastes the stanford-nlp lib and offloads the modeling part to PyLate. I can confirm that this works fine, as I have been using the branch to run evals lately. It could actually be merged; I just need to remove the stanford-nlp code we do not need and hook it properly into the retrieve function (right now it works by calling the index directly, because retrieve first does candidate generation and then reranking, while PLAID handles everything).

I saw that the WARP code is largely based on PLAID (I have not had time to read the paper yet), so I think we could just apply the same modeling changes I made for PLAID to the WARP code and it should work.
The reason I am not directly using the indexes from PLAID is that we need to take the modeling out of the indexes to be in line with PyLate, and this is hard/impossible to do by extending the code. Do you think it's worth going with a WARP branch rather than the PLAID one?

Edit: I will start by extracting a diff patch between the stanford-nlp code and the branch; that will be a good starting point for WARP.

@robro612
Author

robro612 commented Feb 4, 2025

Thanks to the score_function argument of the loss functions, I think XTR training could be implemented entirely at the scoring level, but I'll have to do some more exploring/thinking about how negatives are handled in PyLate to be sure.
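To illustrate why a pluggable score function could be enough, here is a generic in-batch contrastive loss sketch that accepts a score_function callable. This mirrors the idea but not PyLate's actual API; every name and signature below is hypothetical:

```python
import numpy as np

def maxsim_score(q_emb, d_emb):
    # ColBERT-style MaxSim: sum over query tokens of the max doc-token score.
    return (q_emb @ d_emb.T).max(axis=1).sum()

def in_batch_softmax_loss(q_emb, docs_emb, positive_idx, score_function=maxsim_score):
    """Cross-entropy over one query's scores against all in-batch documents.

    Swapping `score_function` for an XTR-style top-k-masked variant would
    change the training objective without touching the loss itself.
    """
    scores = np.array([score_function(q_emb, d) for d in docs_emb])
    # Numerically stable log-softmax.
    max_s = scores.max()
    log_probs = scores - max_s - np.log(np.exp(scores - max_s).sum())
    return -log_probs[positive_idx]
```

Under this design, XTR support reduces to supplying a different score_function; the handling of negatives (in-batch vs. explicit) is the part that would still need checking.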

As for the WARP/PLAID implementation, my reading of the WARP paper is that it could subsume PLAID if, in the decompression stage, it loads the clusters/residuals of all retrieved documents' tokens (feasible, given that we can associate them with docids) and, in the scoring stage (the reductions), simply skips missing-score imputation. Right now, though, my understanding of how directly the code reflects these stages (particularly the reduction) is less complete.

Overall though, WARP does offer some XTR-agnostic improvements over PLAID:

“This refined scoring function is far more efficient than the one outlined in PLAID [9], as it never computes the decompressed embeddings explicitly and instead directly emits the resulting candidate scores.” [Scheerer et al., 2025, p. 5]

so if extricating the retrieval algorithm from the modeling in PLAID already requires significant effort, it would be good to jump directly to reimplementing WARP.
