Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SOLR-15873: factor out a ReRankRescorer class, simplify LTR[Interleaving]Rescorer code #478

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 103 additions & 4 deletions solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@
import org.apache.lucene.index.ReaderUtil;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Rescorer;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.Weight;
import org.apache.solr.ltr.interleaving.OriginalRankingLTRScoringQuery;
import org.apache.solr.search.ReRankRescorer;
import org.apache.solr.search.SolrIndexSearcher;


Expand All @@ -41,7 +41,7 @@
* new score to each document. The top documents will be resorted based on the
* new score.
* */
public class LTRRescorer extends Rescorer {
public class LTRRescorer extends ReRankRescorer {

final private LTRScoringQuery scoringQuery;

Expand Down Expand Up @@ -118,7 +118,11 @@ protected static void heapify(ScoreDoc[] hits, int size) {
* documents to rerank;
* @param topN
* documents to return;

* @deprecated Use {@link #rescore(IndexSearcher, TopDocs)} instead.
* From Solr 9.1.0 onwards this method will be removed.
*/
@Deprecated
@Override
public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs,
int topN) throws IOException {
Expand All @@ -133,6 +137,31 @@ public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs,
return new TopDocs(firstPassTopDocs.totalHits, reranked);
}

/**
 * Rescores every document of {@code firstPassTopDocs} with the LTR model.
 *
 * @param searcher current IndexSearcher
 * @param firstPassTopDocs documents to rerank
 * @return the same hits, re-scored by the model and re-sorted by the new score
 */
@Override
public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs) throws IOException {
  // Nothing to rescore: hand the (empty) first pass straight back.
  if (firstPassTopDocs.scoreDocs.length == 0) {
    return firstPassTopDocs;
  }
  final ScoreDoc[] rescoredHits = rerank(searcher, getFirstPassDocsRanked(firstPassTopDocs));
  return new TopDocs(firstPassTopDocs.totalHits, rescoredHits);
}

/**
* @deprecated Use {@link #rerank(IndexSearcher, ScoreDoc[])} instead.
* From Solr 9.1.0 onwards this method will be removed.
*/
@Deprecated
private ScoreDoc[] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPassResults) throws IOException {
final ScoreDoc[] reranked = new ScoreDoc[topN];
final List<LeafReaderContext> leaves = searcher.getIndexReader().leaves();
Expand All @@ -145,6 +174,18 @@ private ScoreDoc[] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPass
return reranked;
}

/**
 * Re-scores {@code firstPassResults} with the scoring model and returns a new
 * array of hits sorted by their model score.
 */
private ScoreDoc[] rerank(IndexSearcher searcher, ScoreDoc[] firstPassResults) throws IOException {
  final List<LeafReaderContext> segmentLeaves = searcher.getIndexReader().leaves();
  final LTRScoringQuery.ModelWeight ltrModelWeight =
      (LTRScoringQuery.ModelWeight) searcher.createWeight(
          searcher.rewrite(scoringQuery), ScoreMode.COMPLETE, 1);
  final ScoreDoc[] rescored = new ScoreDoc[firstPassResults.length];
  scoreFeatures(searcher, ltrModelWeight, firstPassResults, segmentLeaves, rescored);
  // Every re-scored document must be sorted so callers can select the top ones.
  Arrays.sort(rescored, scoreComparator);
  return rescored;
}

protected static ScoreDoc[] getFirstPassDocsRanked(TopDocs firstPassTopDocs) {
final ScoreDoc[] hits = firstPassTopDocs.scoreDocs;
Arrays.sort(hits, docComparator);
Expand All @@ -153,6 +194,10 @@ protected static ScoreDoc[] getFirstPassDocsRanked(TopDocs firstPassTopDocs) {
return hits;
}

/**
* @deprecated From Solr 9.1.0 onwards this method will be removed.
*/
@Deprecated
public void scoreFeatures(IndexSearcher indexSearcher,
int topN, LTRScoringQuery.ModelWeight modelWeight, ScoreDoc[] hits, List<LeafReaderContext> leaves,
ScoreDoc[] reranked) throws IOException {
Expand Down Expand Up @@ -185,9 +230,39 @@ public void scoreFeatures(IndexSearcher indexSearcher,
}
}

/**
 * Scores every hit in {@code hits} with the model and stores the re-scored hit
 * in the corresponding slot of {@code reranked}; every scored hit is also
 * feature-logged.
 * <p>
 * NOTE(review): the segment cursor below only moves forward, so this assumes
 * {@code hits} is sorted by increasing doc id — callers are expected to obtain
 * that ordering via {@link #getFirstPassDocsRanked(TopDocs)}.
 */
private void scoreFeatures(IndexSearcher indexSearcher,
    LTRScoringQuery.ModelWeight modelWeight, ScoreDoc[] hits, List<LeafReaderContext> leaves,
    ScoreDoc[] reranked) throws IOException {

  int readerUpto = -1; // index into leaves of the current segment
  int endDoc = 0;      // first doc id beyond the current segment
  int docBase = 0;     // doc base of the current segment

  LTRScoringQuery.ModelWeight.ModelScorer scorer = null;
  int hitUpto = 0;

  while (hitUpto < hits.length) {
    final ScoreDoc hit = hits[hitUpto];
    final int docID = hit.doc;
    LeafReaderContext readerContext = null;
    // Advance the segment cursor until the segment containing docID is reached.
    while (docID >= endDoc) {
      readerUpto++;
      readerContext = leaves.get(readerUpto);
      endDoc = readerContext.docBase + readerContext.reader().maxDoc();
    }
    // We advanced to another segment: refresh the per-segment model scorer.
    if (readerContext != null) {
      docBase = readerContext.docBase;
      scorer = modelWeight.scorer(readerContext);
    }
    reranked[hitUpto] = scoreSingleHit(docBase, hit, docID, scorer);
    // Unlike the deprecated topN variant, every hit's feature info is logged.
    logSingleHit(indexSearcher, modelWeight, hit.doc, scoringQuery);
    hitUpto++;
  }
}

/**
* Call this method if the {@link #scoreSingleHit(int, int, int, ScoreDoc, int, org.apache.solr.ltr.LTRScoringQuery.ModelWeight.ModelScorer, ScoreDoc[])}
* method indicated that the document's feature info should be logged.
* Logs a document's feature info.
*/
protected static void logSingleHit(IndexSearcher indexSearcher, LTRScoringQuery.ModelWeight modelWeight, int docid, LTRScoringQuery scoringQuery) {
final FeatureLogger featureLogger = scoringQuery.getFeatureLogger();
Expand All @@ -200,7 +275,9 @@ protected static void logSingleHit(IndexSearcher indexSearcher, LTRScoringQuery.
* Scores a single document and returns true if the document's feature info should be logged via the
* {@link #logSingleHit(IndexSearcher, org.apache.solr.ltr.LTRScoringQuery.ModelWeight, int, LTRScoringQuery)}
* method. Feature info logging is only necessary for the topN documents.
* @deprecated From Solr 9.1.0 onwards this method will be removed.
*/
@Deprecated
protected static boolean scoreSingleHit(int topN, int docBase, int hitUpto, ScoreDoc hit, int docID, LTRScoringQuery.ModelWeight.ModelScorer scorer, ScoreDoc[] reranked) throws IOException {
// Scorer for a LTRScoringQuery.ModelWeight should never be null since we always have to
// call score
Expand Down Expand Up @@ -243,6 +320,28 @@ protected static boolean scoreSingleHit(int topN, int docBase, int hitUpto, Scor
return logHit;
}

/**
 * Scores a single document with the given model scorer and returns the hit,
 * mutated in place: the first-pass score is preserved via
 * {@code setOriginalDocScore} and {@code hit.score} is replaced by the model
 * score.
 *
 * @param docBase doc base of the segment the scorer is positioned on
 * @param hit the first-pass hit to re-score (mutated and returned)
 * @param docID the hit's index-wide document id
 * @param scorer per-segment model scorer; must not be null
 * @return {@code hit}, with its score replaced by the model score
 */
protected static ScoreDoc scoreSingleHit(int docBase, ScoreDoc hit, int docID, LTRScoringQuery.ModelWeight.ModelScorer scorer) throws IOException {
  // Scorer for a LTRScoringQuery.ModelWeight should never be null since we
  // always have to call score even if no feature scorers match, since a model
  // might use that info to return a non-zero score. Same applies for the case
  // of advancing a LTRScoringQuery.ModelWeight.ModelScorer past the target doc
  // since the model algorithm still needs to compute a potentially non-zero
  // score from blank features.
  assert (scorer != null);
  final int targetDoc = docID - docBase;
  // (A previous revision called scorer.docID() here and discarded the result;
  // docID() is a side-effect-free accessor, so the dead call was removed.)
  scorer.iterator().advance(targetDoc);

  scorer.getDocInfo().setOriginalDocScore(hit.score);
  hit.score = scorer.score();
  return hit;
}

@Override
public Explanation explain(IndexSearcher searcher,
Explanation firstPassExplanation, int docID) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,11 @@ public LTRInterleavingRescorer( Interleaving interleavingAlgorithm, LTRInterleav
* documents to rerank;
* @param topN
* documents to return;

* @deprecated Use {@link #rescore(IndexSearcher, TopDocs)} instead.
* From Solr 9.1.0 onwards this method will be removed.
*/
@Deprecated
@Override
public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs,
int topN) throws IOException {
Expand Down Expand Up @@ -91,6 +95,45 @@ public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs,
return new TopDocs(firstPassTopDocs.totalHits, interleavedResults);
}

/**
 * Rescores all the first-pass documents with each reranking query, then
 * interleaves the two resulting rankings into one.
 *
 * @param searcher
 *          current IndexSearcher
 * @param firstPassTopDocs
 *          documents to rerank;
 */
@Override
public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs) throws IOException {
  if (firstPassTopDocs.scoreDocs.length == 0) {
    return firstPassTopDocs;
  }

  // Snapshot the original ranking BEFORE getFirstPassDocsRanked below
  // re-sorts firstPassTopDocs.scoreDocs in place.
  ScoreDoc[] firstPassResults = null;
  if(originalRankingIndex != null) {
    firstPassResults = new ScoreDoc[firstPassTopDocs.scoreDocs.length];
    System.arraycopy(firstPassTopDocs.scoreDocs, 0, firstPassResults, 0, firstPassTopDocs.scoreDocs.length);
  }

  ScoreDoc[][] reRankedPerModel = rerank(searcher,getFirstPassDocsRanked(firstPassTopDocs));
  // Put the untouched original ranking back into its reserved slot
  // (rerank leaves that slot unscored).
  if (originalRankingIndex != null) {
    reRankedPerModel[originalRankingIndex] = firstPassResults;
  }
  InterleavingResult interleaved = interleavingAlgorithm.interleave(reRankedPerModel[0], reRankedPerModel[1]);
  ScoreDoc[] interleavedResults = interleaved.getInterleavedResults();

  // Record on each reranking query which documents it contributed to the
  // interleaved result.
  ArrayList<Set<Integer>> interleavingPicks = interleaved.getInterleavingPicks();
  rerankingQueries[0].setPickedInterleavingDocIds(interleavingPicks.get(0));
  rerankingQueries[1].setPickedInterleavingDocIds(interleavingPicks.get(1));

  return new TopDocs(firstPassTopDocs.totalHits, interleavedResults);
}

/**
* @deprecated Use {@link #rerank(IndexSearcher, ScoreDoc[])} instead.
* From Solr 9.1.0 onwards this method will be removed.
*/
@Deprecated
private ScoreDoc[][] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPassResults) throws IOException {
ScoreDoc[][] reRankedPerModel = new ScoreDoc[rerankingQueries.length][topN];
final List<LeafReaderContext> leaves = searcher.getIndexReader().leaves();
Expand All @@ -112,6 +155,31 @@ private ScoreDoc[][] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPa
return reRankedPerModel;
}

/**
 * Builds one re-ranked hit list per reranking query. The slot reserved for the
 * original ranking (if any) gets no model weight and is not re-sorted; the
 * caller fills it in afterwards.
 */
private ScoreDoc[][] rerank(IndexSearcher searcher, ScoreDoc[] firstPassResults) throws IOException {
  final int modelCount = rerankingQueries.length;
  final ScoreDoc[][] rerankedPerModel = new ScoreDoc[modelCount][firstPassResults.length];
  final List<LeafReaderContext> segmentLeaves = searcher.getIndexReader().leaves();
  final LTRScoringQuery.ModelWeight[] modelWeights = new LTRScoringQuery.ModelWeight[modelCount];

  // Create a weight for every slot except the original-ranking one.
  for (int model = 0; model < modelCount; model++) {
    final boolean isOriginalRanking = (originalRankingIndex != null) && (originalRankingIndex == model);
    if (!isOriginalRanking) {
      modelWeights[model] = (LTRScoringQuery.ModelWeight) searcher
          .createWeight(searcher.rewrite(rerankingQueries[model]), ScoreMode.COMPLETE, 1);
    }
  }
  scoreFeatures(searcher, modelWeights, firstPassResults, segmentLeaves, rerankedPerModel);

  // Sort each re-scored model's hits by the new score; skip the original slot.
  for (int model = 0; model < modelCount; model++) {
    final boolean isOriginalRanking = (originalRankingIndex != null) && (originalRankingIndex == model);
    if (!isOriginalRanking) {
      Arrays.sort(rerankedPerModel[model], scoreComparator);
    }
  }

  return rerankedPerModel;
}

/**
* @deprecated From Solr 9.1.0 onwards this method will be removed.
*/
@Deprecated
public void scoreFeatures(IndexSearcher indexSearcher,
int topN, LTRScoringQuery.ModelWeight[] modelWeights, ScoreDoc[] hits, List<LeafReaderContext> leaves,
ScoreDoc[][] rerankedPerModel) throws IOException {
Expand Down Expand Up @@ -153,6 +221,46 @@ public void scoreFeatures(IndexSearcher indexSearcher,

}

/**
 * Scores every hit with every non-null model weight, filling one row of
 * {@code rerankedPerModel} per model. Each model scores a private copy of the
 * hit, so the incoming {@code hits} objects are never mutated here.
 * <p>
 * NOTE(review): the segment cursor only moves forward, so this assumes
 * {@code hits} is sorted by increasing doc id — callers are expected to obtain
 * that ordering via {@code getFirstPassDocsRanked}.
 */
private void scoreFeatures(IndexSearcher indexSearcher,
    LTRScoringQuery.ModelWeight[] modelWeights, ScoreDoc[] hits, List<LeafReaderContext> leaves,
    ScoreDoc[][] rerankedPerModel) throws IOException {

  int readerUpto = -1; // index into leaves of the current segment
  int endDoc = 0;      // first doc id beyond the current segment
  int docBase = 0;     // doc base of the current segment
  int hitUpto = 0;
  // One per-segment scorer per model; slots stay null for the original ranking.
  LTRScoringQuery.ModelWeight.ModelScorer[] scorers = new LTRScoringQuery.ModelWeight.ModelScorer[rerankingQueries.length];
  while (hitUpto < hits.length) {
    final ScoreDoc hit = hits[hitUpto];
    final int docID = hit.doc;
    LeafReaderContext readerContext = null;
    // Advance the segment cursor until the segment containing docID is reached.
    while (docID >= endDoc) {
      readerUpto++;
      readerContext = leaves.get(readerUpto);
      endDoc = readerContext.docBase + readerContext.reader().maxDoc();
    }

    // We advanced to another segment: refresh every model's per-segment scorer.
    if (readerContext != null) {
      docBase = readerContext.docBase;
      for (int i = 0; i < modelWeights.length; i++) {
        if (modelWeights[i] != null) {
          scorers[i] = modelWeights[i].scorer(readerContext);
        }
      }
    }
    for (int i = 0; i < rerankingQueries.length; i++) {
      if (modelWeights[i] != null) {
        // Copy the hit so each model mutates its own ScoreDoc, not the shared one.
        final ScoreDoc hit_i = new ScoreDoc(hit.doc, hit.score, hit.shardIndex);
        rerankedPerModel[i][hitUpto] = scoreSingleHit(docBase, hit_i, docID, scorers[i]);
        logSingleHit(indexSearcher, modelWeights[i], hit_i.doc, rerankingQueries[i]);
      }
    }
    hitUpto++;
  }

}

@Override
public Explanation explain(IndexSearcher searcher,
Explanation firstPassExplanation, int docID) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ public void testRescorer() throws Exception {
LTRScoringQuery ltrScoringQuery = new LTRScoringQuery(ltrScoringModel);
ltrScoringQuery.setRequest(solrQueryRequest);
final LTRRescorer rescorer = new LTRRescorer(ltrScoringQuery);
hits = rescorer.rescore(searcher, hits, 2);
hits = random().nextBoolean() ? rescorer.rescore(searcher, hits) : rescorer.rescore(searcher, hits, 2);

// rerank using the field finalScore
assertEquals("1", searcher.doc(hits.scoreDocs[0].doc).get("id"));
Expand Down Expand Up @@ -189,17 +189,24 @@ public void testDifferentTopN() throws IOException {
assertEquals("2", searcher.doc(hits.scoreDocs[2].doc).get("id"));
assertEquals("3", searcher.doc(hits.scoreDocs[3].doc).get("id"));
assertEquals("4", searcher.doc(hits.scoreDocs[4].doc).get("id"));
{
final TopDocs noHits = new TopDocs(hits.totalHits, new ScoreDoc[0]);
final TopDocs noHitsRescored = rescorer.rescore(searcher, noHits);
assertEquals(0, noHitsRescored.scoreDocs.length);
}

// test rerank with different topN cuts

// cap firstPassTopDocs length at topN
for (int topN = 1; topN <= 5; topN++) {
log.info("rerank {} documents ", topN);
log.info("rerank {} documents, return {} documents", topN, topN);
hits = searcher.search(bqBuilder.build(), 10);

final ScoreDoc[] slice = new ScoreDoc[topN];
System.arraycopy(hits.scoreDocs, 0, slice, 0, topN);
hits = new TopDocs(hits.totalHits, slice);
hits = rescorer.rescore(searcher, hits, topN);
hits = random().nextBoolean() ? rescorer.rescore(searcher, hits) : rescorer.rescore(searcher, hits, topN);
assertEquals(topN, hits.scoreDocs.length);
for (int i = topN - 1, j = 0; i >= 0; i--, j++) {
if (log.isInfoEnabled()) {
log.info("doc {} in pos {}", searcher.doc(hits.scoreDocs[j].doc)
Expand All @@ -211,6 +218,25 @@ public void testDifferentTopN() throws IOException {

}
}

// use full firstPassTopDocs (possibly higher than topN)
for (int topN = 1; topN <= 5; topN++) {
final TopDocs allHits = searcher.search(bqBuilder.build(), 10);
log.info("rerank {} documents, return {} documents", allHits.scoreDocs.length, topN);

TopDocs rescoredHits = rescorer.rescore(searcher, allHits, topN);
assertEquals(topN, rescoredHits.scoreDocs.length);
for (int i = allHits.scoreDocs.length-1, j = 0; i >= 0 && j < topN; i--, j++) {
if (log.isInfoEnabled()) {
log.info("doc {} in pos {}", searcher.doc(rescoredHits.scoreDocs[j].doc)
.get("id"), j);
}
assertEquals(i,
Integer.parseInt(searcher.doc(rescoredHits.scoreDocs[j].doc).get("id")));
assertEquals((i + 1) * features.size()*featureWeight, rescoredHits.scoreDocs[j].score, 0.00001);

}
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,12 @@ public TopDocs topDocs(int start, int howMany) {

mainDocs.scoreDocs = reRankScoreDocs;

TopDocs rescoredDocs = reRankQueryRescorer
.rescore(searcher, mainDocs, mainDocs.scoreDocs.length);
final TopDocs rescoredDocs;
if (reRankQueryRescorer instanceof ReRankRescorer) {
rescoredDocs = ((ReRankRescorer) reRankQueryRescorer).rescore(searcher, mainDocs);
} else {
rescoredDocs = reRankQueryRescorer.rescore(searcher, mainDocs, mainDocs.scoreDocs.length);
}

//Lower howMany to return if we've collected fewer documents.
howMany = Math.min(howMany, mainScoreDocs.length);
Expand Down
Loading