Add XTR scoring #86
Amazing @robro612, this would be awesome. Feel free to ask any questions or to ask for an early review if you want to! If I understand correctly, we need PLAID up and running to have the WARP retriever? 🚀
This issue will allow people to use PyLate to train XTR models, and the naive (non-WARP) retrieval is pretty straightforward. I haven't done a super deep dive on the WARP code; it uses many of the same ColBERT structures, but I can't speak to how interchangeable WARP and PLAID would be at an implementation level.
Hello @robro612, thanks for the initiative, appreciate it! As mentioned in #83, I am interested in adding PLAID and XTR, as well as maybe MUVERA, to PyLate, so this would definitely be appreciated. I think a dedicated loss function would be adequate for training, and we can obviously add a specific scoring method as well if required. That covers training.

For indexing, as mentioned in #83, someone offered to help a bit with a native PLAID/MUVERA implementation that does not rely on the stanford-nlp code. Until this happens (if it does), maybe we can go with what I started to build on the PLAID branch, which basically copies and pastes the stanford-nlp lib and offloads the modeling part to PyLate. I can confirm that this works fine, as I have been using the branch to run evals lately; it could actually be merged. I just need to remove the code from stanford-nlp that we do not need and also add it properly to the retrieve function (right now it works by just calling the index, because retrieve first does candidate generation and then reranking, while PLAID handles everything). I saw that the code from WARP is heavily based on PLAID (I did not have time to read the paper yet), so I think we could just apply the same changes I made to the modeling for PLAID to the code of WARP, and it should work.

Edit: I will start by extracting a diff patch of the stanford code against the one on the branch; it will be a good starting point for WARP.
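To make that retrieve-vs-index distinction concrete, here is a rough sketch of the dispatch being described. Every name here (`end_to_end`, `doc_store`, the 10×k over-retrieval factor, `maxsim`) is a hypothetical assumption for illustration, not PyLate's actual API:

```python
import torch


def maxsim(query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor) -> torch.Tensor:
    """Late-interaction score: for each query token, take the best-matching
    document token similarity, then sum over query tokens."""
    return (query_embeddings @ doc_embeddings.T).max(dim=1).values.sum()


def retrieve(index, query_embeddings: torch.Tensor, doc_store: dict, k: int = 10):
    """Hypothetical dispatch, not PyLate's actual API. PLAID-style indexes
    rank candidates internally, so a single call suffices; other indexes
    only generate candidates, which are then reranked with exact MaxSim
    using the stored token embeddings (doc_store: doc_id -> tensor)."""
    if getattr(index, "end_to_end", False):
        # PLAID (and presumably WARP): candidate generation *and*
        # reranking happen inside the index, one call returns rankings.
        return index.search(query_embeddings, k=k)
    # Stage 1: approximate candidate generation (over-retrieve a bit).
    candidate_ids = index.search(query_embeddings, k=10 * k)
    # Stage 2: exact rescoring of the candidates' full token embeddings.
    scored = [
        (doc_id, maxsim(query_embeddings, doc_store[doc_id]).item())
        for doc_id in candidate_ids
    ]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)[:k]
```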
Thanks! As for the WARP/PLAID implementation, my interpretation of the WARP paper is that it could subsume PLAID: in the decompression stage it would load the clusters/residuals of all retrieved documents' tokens (feasible given that we can associate them with docids), and in the scoring stage (the reductions) it would simply not need to perform missing-score imputation. Right now, though, my understanding of how directly the code reflects these stages (particularly the reduction) is less complete. Overall, WARP does offer some XTR-agnostic improvements over PLAID, so if there is already a significant effort underway to extricate the retrieval algorithm from the modeling in PLAID, it would be good to jump directly to reimplementing WARP.
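For concreteness, here is a hedged sketch of that subsumption argument (shapes and names are mine, not the WARP codebase's): if the reduction can optionally impute missing scores, then running it with every candidate token decompressed and imputation turned off is just PLAID-style exhaustive scoring.

```python
import torch


def reduce_scores(
    token_scores: torch.Tensor,  # (num_query_tokens, num_candidates)
    observed: torch.Tensor,      # same shape, bool: True if retrieved
    impute_missing: bool = True,
) -> torch.Tensor:
    """Illustrative reduction stage (my reading of the paper, not the WARP
    code). token_scores[i, j] is the best similarity between query token i
    and any *retrieved* token of candidate j. With impute_missing=True,
    unseen cells get that query token's minimum observed score (the XTR
    trick); with impute_missing=False and ``observed`` all True (i.e.,
    every candidate token decompressed), this is plain PLAID-style
    sum-of-max scoring."""
    if impute_missing:
        # Lowest score actually observed for each query token.
        floor = token_scores.masked_fill(~observed, float("inf")).min(
            dim=1, keepdim=True
        ).values
        # Fill the missing (query token, candidate) cells with that floor.
        token_scores = torch.where(observed, token_scores, floor)
    return token_scores.sum(dim=0)  # (num_candidates,)
```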
Recently, WARP has reminded us about, and improved upon, XTR's modified inference scheme, which obviates the need to load every embedding of the retrieved candidate documents by imputing each missing query-document score as the minimum score seen among the retrieved scores for that query token.
This change, however, requires an analogous modification of the scoring function during training: only keeping query-document token scores that are in the top `k_train` of all in-batch document tokens for a given query token. This change should be somewhat straightforward, mainly with one or multiple (depending on the multiplicity of documents to queries) `xtr_scores` functions in `scores/scores.py`.
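A minimal sketch of what such a function could look like (the name `xtr_scores` and the file location come from this issue; the signature and everything else are assumptions on my part):

```python
import torch


def xtr_scores(
    queries_embeddings: torch.Tensor,    # (num_queries, query_len, dim)
    documents_embeddings: torch.Tensor,  # (num_docs, doc_len, dim)
    k_train: int = 32,
) -> torch.Tensor:
    """Hypothetical training-time XTR scoring. Like the usual MaxSim,
    except a query-token/document-token similarity only contributes if it
    is among the top ``k_train`` similarities for that query token over
    *all* in-batch document tokens."""
    num_docs, doc_len, _ = documents_embeddings.shape
    # Token-level similarities: (num_queries, query_len, num_docs, doc_len).
    sims = torch.einsum("qld,ntd->qlnt", queries_embeddings, documents_embeddings)
    # Flatten every in-batch document token to find each query token's top-k.
    flat = sims.reshape(sims.shape[0], sims.shape[1], num_docs * doc_len)
    k = min(k_train, flat.shape[-1])
    threshold = flat.topk(k, dim=-1).values[..., -1:]  # k-th best similarity
    # Zero out similarities that fall outside the top-k.
    kept = flat.masked_fill(flat < threshold, 0.0).reshape_as(sims)
    # Max over document tokens, then average over query tokens. (The XTR
    # paper normalizes by the number of query tokens with a surviving
    # match; a plain mean keeps this sketch simple.)
    return kept.amax(dim=-1).mean(dim=1)  # (num_queries, num_docs)
```

The resulting `(num_queries, num_docs)` score matrix could then feed an in-batch contrastive loss, analogously to the existing ColBERT scores.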
I'd like to take this on if that's alright with y'all, @NohTow @raphaelsty. It could be a good prelude to adding WARP support after the PLAID implementation is up and running.