Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SOLR-15873: simplify LTR[Interleaving]Rescorer code #475

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 38 additions & 20 deletions solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@
import org.apache.lucene.index.ReaderUtil;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Rescorer;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.Weight;
import org.apache.solr.ltr.interleaving.OriginalRankingLTRScoringQuery;
import org.apache.solr.search.ReRankRescorer;
import org.apache.solr.search.SolrIndexSearcher;


Expand All @@ -41,7 +41,7 @@
* new score to each document. The top documents will be resorted based on the
* new score.
* */
public class LTRRescorer extends Rescorer {
public class LTRRescorer extends ReRankRescorer {

final private LTRScoringQuery scoringQuery;

Expand Down Expand Up @@ -110,36 +110,32 @@ protected static void heapify(ScoreDoc[] hits, int size) {
}

/**
* rescores the documents:
* rescores all the documents:
*
* @param searcher
* current IndexSearcher
* @param firstPassTopDocs
* documents to rerank;
* @param topN
* documents to return;
*/
@Override
public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs,
int topN) throws IOException {
if ((topN == 0) || (firstPassTopDocs.scoreDocs.length == 0)) {
public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs) throws IOException {
if (firstPassTopDocs.scoreDocs.length == 0) {
return firstPassTopDocs;
}
final ScoreDoc[] firstPassResults = getFirstPassDocsRanked(firstPassTopDocs);
topN = Math.toIntExact(Math.min(topN, firstPassTopDocs.totalHits.value));

final ScoreDoc[] reranked = rerank(searcher, topN, firstPassResults);
final ScoreDoc[] reranked = rerank(searcher, firstPassResults);

return new TopDocs(firstPassTopDocs.totalHits, reranked);
}

private ScoreDoc[] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPassResults) throws IOException {
final ScoreDoc[] reranked = new ScoreDoc[topN];
private ScoreDoc[] rerank(IndexSearcher searcher, ScoreDoc[] firstPassResults) throws IOException {
final ScoreDoc[] reranked = new ScoreDoc[firstPassResults.length];
final List<LeafReaderContext> leaves = searcher.getIndexReader().leaves();
final LTRScoringQuery.ModelWeight modelWeight = (LTRScoringQuery.ModelWeight) searcher
.createWeight(searcher.rewrite(scoringQuery), ScoreMode.COMPLETE, 1);

scoreFeatures(searcher,topN, modelWeight, firstPassResults, leaves, reranked);
scoreFeatures(searcher, modelWeight, firstPassResults, leaves, reranked);
// Must sort all documents that we reranked, and then select the top
Arrays.sort(reranked, scoreComparator);
return reranked;
Expand All @@ -153,8 +149,8 @@ protected static ScoreDoc[] getFirstPassDocsRanked(TopDocs firstPassTopDocs) {
return hits;
}

public void scoreFeatures(IndexSearcher indexSearcher,
int topN, LTRScoringQuery.ModelWeight modelWeight, ScoreDoc[] hits, List<LeafReaderContext> leaves,
private void scoreFeatures(IndexSearcher indexSearcher,
LTRScoringQuery.ModelWeight modelWeight, ScoreDoc[] hits, List<LeafReaderContext> leaves,
Comment on lines -156 to +153
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be considered backwards incompatible (a signature and behaviour change to a public method). However, if #473 first marks the public method as deprecated in a prior Solr release, then this could be considered a backwards-compatible change.

ScoreDoc[] reranked) throws IOException {

int readerUpto = -1;
Expand All @@ -178,16 +174,14 @@ public void scoreFeatures(IndexSearcher indexSearcher,
docBase = readerContext.docBase;
scorer = modelWeight.scorer(readerContext);
}
if (scoreSingleHit(topN, docBase, hitUpto, hit, docID, scorer, reranked)) {
logSingleHit(indexSearcher, modelWeight, hit.doc, scoringQuery);
}
reranked[hitUpto] = scoreSingleHit(docBase, hit, docID, scorer);
logSingleHit(indexSearcher, modelWeight, hit.doc, scoringQuery);
hitUpto++;
}
}

/**
* Call this method if the {@link #scoreSingleHit(int, int, int, ScoreDoc, int, org.apache.solr.ltr.LTRScoringQuery.ModelWeight.ModelScorer, ScoreDoc[])}
* method indicated that the document's feature info should be logged.
* Logs a document's feature info.
*/
protected static void logSingleHit(IndexSearcher indexSearcher, LTRScoringQuery.ModelWeight modelWeight, int docid, LTRScoringQuery scoringQuery) {
final FeatureLogger featureLogger = scoringQuery.getFeatureLogger();
Expand All @@ -196,11 +190,35 @@ protected static void logSingleHit(IndexSearcher indexSearcher, LTRScoringQuery.
}
}

/**
 * Computes the model score for one first-pass hit, stores it on the hit and
 * hands the same (mutated) hit back to the caller.
 *
 * @param docBase doc id offset that is subtracted from {@code docID} to obtain the advance target
 * @param hit     the hit to score; its {@code score} field is overwritten with the model score
 * @param docID   doc id of the hit, before {@code docBase} is subtracted
 * @param scorer  model scorer to advance and score with; must not be null
 * @return the {@code hit} instance that was passed in, now carrying the model score
 */
protected static ScoreDoc scoreSingleHit(int docBase, ScoreDoc hit, int docID, LTRScoringQuery.ModelWeight.ModelScorer scorer) throws IOException {
  // A LTRScoringQuery.ModelWeight scorer should never be null: score has to be
  // called even when no feature scorer matches, because a model might still
  // derive a non-zero score from that information. The same reasoning applies
  // when a LTRScoringQuery.ModelWeight.ModelScorer is advanced past the target
  // doc: the model algorithm still needs to compute a potentially non-zero
  // score from blank features.
  assert scorer != null;

  scorer.docID();
  scorer.iterator().advance(docID - docBase);

  // Hand the first-pass score to the scorer's doc info before overwriting it on the hit.
  final float firstPassScore = hit.score;
  scorer.getDocInfo().setOriginalDocScore(firstPassScore);

  hit.score = scorer.score();
  return hit;
}

/**
* Scores a single document and returns true if the document's feature info should be logged via the
* {@link #logSingleHit(IndexSearcher, org.apache.solr.ltr.LTRScoringQuery.ModelWeight, int, LTRScoringQuery)}
* method. Feature info logging is only necessary for the topN documents.
* @deprecated From Solr 9.2.0 onwards this method will be removed.
*/
@Deprecated
Comment on lines +219 to +221
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method is now unused but would be retained for (say) one Solr release for backwards compatibility.

protected static boolean scoreSingleHit(int topN, int docBase, int hitUpto, ScoreDoc hit, int docID, LTRScoringQuery.ModelWeight.ModelScorer scorer, ScoreDoc[] reranked) throws IOException {
// Scorer for a LTRScoringQuery.ModelWeight should never be null since we always have to
// call score
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,19 +54,16 @@ public LTRInterleavingRescorer( Interleaving interleavingAlgorithm, LTRInterleav
}

/**
* rescores the documents:
* rescores all the documents:
*
* @param searcher
* current IndexSearcher
* @param firstPassTopDocs
* documents to rerank;
* @param topN
* documents to return;
*/
@Override
public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs,
int topN) throws IOException {
if ((topN == 0) || (firstPassTopDocs.scoreDocs.length == 0)) {
public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs) throws IOException {
if (firstPassTopDocs.scoreDocs.length == 0) {
return firstPassTopDocs;
}

Expand All @@ -75,9 +72,8 @@ public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs,
firstPassResults = new ScoreDoc[firstPassTopDocs.scoreDocs.length];
System.arraycopy(firstPassTopDocs.scoreDocs, 0, firstPassResults, 0, firstPassTopDocs.scoreDocs.length);
}
topN = Math.toIntExact(Math.min(topN, firstPassTopDocs.totalHits.value));

ScoreDoc[][] reRankedPerModel = rerank(searcher,topN,getFirstPassDocsRanked(firstPassTopDocs));
ScoreDoc[][] reRankedPerModel = rerank(searcher,getFirstPassDocsRanked(firstPassTopDocs));
if (originalRankingIndex != null) {
reRankedPerModel[originalRankingIndex] = firstPassResults;
}
Expand All @@ -91,8 +87,8 @@ public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs,
return new TopDocs(firstPassTopDocs.totalHits, interleavedResults);
}

private ScoreDoc[][] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPassResults) throws IOException {
ScoreDoc[][] reRankedPerModel = new ScoreDoc[rerankingQueries.length][topN];
private ScoreDoc[][] rerank(IndexSearcher searcher, ScoreDoc[] firstPassResults) throws IOException {
ScoreDoc[][] reRankedPerModel = new ScoreDoc[rerankingQueries.length][firstPassResults.length];
final List<LeafReaderContext> leaves = searcher.getIndexReader().leaves();
LTRScoringQuery.ModelWeight[] modelWeights = new LTRScoringQuery.ModelWeight[rerankingQueries.length];
for (int i = 0; i < rerankingQueries.length; i++) {
Expand All @@ -101,7 +97,7 @@ private ScoreDoc[][] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPa
.createWeight(searcher.rewrite(rerankingQueries[i]), ScoreMode.COMPLETE, 1);
}
}
scoreFeatures(searcher, topN, modelWeights, firstPassResults, leaves, reRankedPerModel);
scoreFeatures(searcher, modelWeights, firstPassResults, leaves, reRankedPerModel);

for (int i = 0; i < rerankingQueries.length; i++) {
if (originalRankingIndex == null || originalRankingIndex != i) {
Expand All @@ -112,8 +108,8 @@ private ScoreDoc[][] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPa
return reRankedPerModel;
}

public void scoreFeatures(IndexSearcher indexSearcher,
int topN, LTRScoringQuery.ModelWeight[] modelWeights, ScoreDoc[] hits, List<LeafReaderContext> leaves,
private void scoreFeatures(IndexSearcher indexSearcher,
LTRScoringQuery.ModelWeight[] modelWeights, ScoreDoc[] hits, List<LeafReaderContext> leaves,
Comment on lines -115 to +112
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be considered backwards incompatible (a signature and behaviour change to a public method). However, if #473 first marks the public method as deprecated in a prior Solr release, then this could be considered a backwards-compatible change.

ScoreDoc[][] rerankedPerModel) throws IOException {

int readerUpto = -1;
Expand Down Expand Up @@ -143,9 +139,8 @@ public void scoreFeatures(IndexSearcher indexSearcher,
for (int i = 0; i < rerankingQueries.length; i++) {
if (modelWeights[i] != null) {
final ScoreDoc hit_i = new ScoreDoc(hit.doc, hit.score, hit.shardIndex);
if (scoreSingleHit(topN, docBase, hitUpto, hit_i, docID, scorers[i], rerankedPerModel[i])) {
logSingleHit(indexSearcher, modelWeights[i], hit_i.doc, rerankingQueries[i]);
}
rerankedPerModel[i][hitUpto] = scoreSingleHit(docBase, hit_i, docID, scorers[i]);
logSingleHit(indexSearcher, modelWeights[i], hit_i.doc, rerankingQueries[i]);
}
}
hitUpto++;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ public void testRescorer() throws Exception {
LTRScoringQuery ltrScoringQuery = new LTRScoringQuery(ltrScoringModel);
ltrScoringQuery.setRequest(solrQueryRequest);
final LTRRescorer rescorer = new LTRRescorer(ltrScoringQuery);
hits = rescorer.rescore(searcher, hits, 2);
assertEquals(2, hits.scoreDocs.length);
hits = rescorer.rescore(searcher, hits);

// rerank using the field finalScore
assertEquals("1", searcher.doc(hits.scoreDocs[0].doc).get("id"));
Expand Down Expand Up @@ -183,23 +184,24 @@ public void testDifferentTopN() throws IOException {
final LTRRescorer rescorer = new LTRRescorer(scoringQuery);

// rerank @ 0 should not change the order
hits = rescorer.rescore(searcher, hits, 0);
assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id"));
assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id"));
assertEquals("2", searcher.doc(hits.scoreDocs[2].doc).get("id"));
assertEquals("3", searcher.doc(hits.scoreDocs[3].doc).get("id"));
assertEquals("4", searcher.doc(hits.scoreDocs[4].doc).get("id"));
{
final TopDocs noHits = new TopDocs(hits.totalHits, new ScoreDoc[0]);
final TopDocs noHitsRescored = rescorer.rescore(searcher, noHits);
assertEquals(0, noHitsRescored.scoreDocs.length);
}

// test rerank with different topN cuts

// cap firstPassTopDocs length at topN
for (int topN = 1; topN <= 5; topN++) {
log.info("rerank {} documents ", topN);
log.info("rerank {} documents, return {} documents", topN, topN);
hits = searcher.search(bqBuilder.build(), 10);

final ScoreDoc[] slice = new ScoreDoc[topN];
System.arraycopy(hits.scoreDocs, 0, slice, 0, topN);
hits = new TopDocs(hits.totalHits, slice);
hits = rescorer.rescore(searcher, hits, topN);
hits = rescorer.rescore(searcher, hits);
assertEquals(topN, hits.scoreDocs.length);
for (int i = topN - 1, j = 0; i >= 0; i--, j++) {
if (log.isInfoEnabled()) {
log.info("doc {} in pos {}", searcher.doc(hits.scoreDocs[j].doc)
Expand All @@ -211,6 +213,17 @@ public void testDifferentTopN() throws IOException {

}
}

// use full firstPassTopDocs (possibly higher than topN)
for (int topN = 0; topN <= 5; topN++) {
final TopDocs allHits = searcher.search(bqBuilder.build(), 10);
log.info("rerank {} documents, return {} documents", allHits.scoreDocs.length, topN);

final int topNN = topN;
expectThrows(UnsupportedOperationException.class, () -> {
rescorer.rescore(searcher, allHits, topNN);
});
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,12 @@ public TopDocs topDocs(int start, int howMany) {

mainDocs.scoreDocs = reRankScoreDocs;

TopDocs rescoredDocs = reRankQueryRescorer
.rescore(searcher, mainDocs, mainDocs.scoreDocs.length);
final TopDocs rescoredDocs;
if (reRankQueryRescorer instanceof ReRankRescorer) {
rescoredDocs = ((ReRankRescorer) reRankQueryRescorer).rescore(searcher, mainDocs);
} else {
rescoredDocs = reRankQueryRescorer.rescore(searcher, mainDocs, mainDocs.scoreDocs.length);
}

//Lower howMany to return if we've collected fewer documents.
howMany = Math.min(howMany, mainScoreDocs.length);
Expand Down
53 changes: 53 additions & 0 deletions solr/core/src/java/org/apache/solr/search/ReRankRescorer.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.search;

import java.io.IOException;

import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Rescorer;
import org.apache.lucene.search.TopDocs;

/**
 * A variant of Lucene's {@link Rescorer} class for use by the
 * {@link ReRankCollector} to rescore all results in a {@link TopDocs} object
 * from an original query.
 *
 * <p>Subclasses implement the two-argument
 * {@link #rescore(IndexSearcher, TopDocs)}; the three-argument variant that
 * takes a {@code topN} is deliberately disabled here and always throws.
 */
public abstract class ReRankRescorer extends Rescorer {

  /**
   * Rescore an initial first-pass {@link TopDocs}.
   *
   * @param searcher {@link IndexSearcher} used to produce the first pass topDocs
   * @param firstPassTopDocs Hits from the first pass search that are to be rescored.
   *        It's very important that these hits were produced by the provided searcher;
   *        otherwise the doc IDs will not match!
   * @return the rescored hits
   * @throws IOException if an implementation encounters an I/O error while rescoring
   */
  public abstract TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs)
      throws IOException;

  /**
   * Always throws an {@link UnsupportedOperationException}.
   * Use {@link #rescore(IndexSearcher, TopDocs)} instead.
   *
   * <p>The method is {@code final} so that subclasses cannot re-enable it.
   *
   * @param searcher ignored
   * @param firstPassTopDocs ignored
   * @param topN ignored
   * @return never returns normally
   * @throws UnsupportedOperationException always
   */
  public final TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, int topN)
      throws IOException {
    // Carry a message so that a caller that reaches this overload (e.g. through a
    // plain Rescorer reference) is told which method to call instead.
    throw new UnsupportedOperationException(
        "rescore(IndexSearcher, TopDocs, int) is not supported by ReRankRescorer;"
            + " use rescore(IndexSearcher, TopDocs) instead");
  }

}