diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java index 4909a6f8a96..0f1fead8ac3 100644 --- a/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java @@ -25,13 +25,13 @@ import org.apache.lucene.index.ReaderUtil; import org.apache.lucene.search.Explanation; import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.Rescorer; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.Weight; import org.apache.solr.ltr.interleaving.OriginalRankingLTRScoringQuery; +import org.apache.solr.search.ReRankRescorer; import org.apache.solr.search.SolrIndexSearcher; @@ -41,7 +41,7 @@ * new score to each document. The top documents will be resorted based on the * new score. * */ -public class LTRRescorer extends Rescorer { +public class LTRRescorer extends ReRankRescorer { final private LTRScoringQuery scoringQuery; @@ -110,36 +110,32 @@ protected static void heapify(ScoreDoc[] hits, int size) { } /** - * rescores the documents: + * rescores all the documents: * * @param searcher * current IndexSearcher * @param firstPassTopDocs * documents to rerank; - * @param topN - * documents to return; */ @Override - public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, - int topN) throws IOException { - if ((topN == 0) || (firstPassTopDocs.scoreDocs.length == 0)) { + public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs) throws IOException { + if (firstPassTopDocs.scoreDocs.length == 0) { return firstPassTopDocs; } final ScoreDoc[] firstPassResults = getFirstPassDocsRanked(firstPassTopDocs); - topN = Math.toIntExact(Math.min(topN, firstPassTopDocs.totalHits.value)); - final ScoreDoc[] reranked = rerank(searcher, topN, firstPassResults); + final ScoreDoc[] reranked = rerank(searcher, firstPassResults); return new TopDocs(firstPassTopDocs.totalHits, reranked); } - private ScoreDoc[] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPassResults) throws IOException { - final ScoreDoc[] reranked = new ScoreDoc[topN]; + private ScoreDoc[] rerank(IndexSearcher searcher, ScoreDoc[] firstPassResults) throws IOException { + final ScoreDoc[] reranked = new ScoreDoc[firstPassResults.length]; final List<LeafReaderContext> leaves = searcher.getIndexReader().leaves(); final LTRScoringQuery.ModelWeight modelWeight = (LTRScoringQuery.ModelWeight) searcher .createWeight(searcher.rewrite(scoringQuery), ScoreMode.COMPLETE, 1); - scoreFeatures(searcher,topN, modelWeight, firstPassResults, leaves, reranked); + scoreFeatures(searcher, modelWeight, firstPassResults, leaves, reranked); // Must sort all documents that we reranked, and then select the top Arrays.sort(reranked, scoreComparator); return reranked; @@ -153,8 +149,8 @@ protected static ScoreDoc[] getFirstPassDocsRanked(TopDocs firstPassTopDocs) { return hits; } - public void scoreFeatures(IndexSearcher indexSearcher, - int topN, LTRScoringQuery.ModelWeight modelWeight, ScoreDoc[] hits, List<LeafReaderContext> leaves, + private void scoreFeatures(IndexSearcher indexSearcher, + LTRScoringQuery.ModelWeight modelWeight, ScoreDoc[] hits, List<LeafReaderContext> leaves, ScoreDoc[] reranked) throws IOException { int readerUpto = -1; @@ -178,16 +174,14 @@ public void scoreFeatures(IndexSearcher indexSearcher, docBase = readerContext.docBase; scorer = modelWeight.scorer(readerContext); } - if (scoreSingleHit(topN, docBase, hitUpto, hit, docID, scorer, reranked)) { - logSingleHit(indexSearcher, modelWeight, hit.doc, scoringQuery); - } + reranked[hitUpto] = scoreSingleHit(docBase, hit, docID, scorer); + logSingleHit(indexSearcher, modelWeight, hit.doc, scoringQuery); hitUpto++; } } /** - * Call this method if the {@link #scoreSingleHit(int, int, int, ScoreDoc, int, org.apache.solr.ltr.LTRScoringQuery.ModelWeight.ModelScorer, ScoreDoc[])} - * method indicated that the document's feature info should be logged. + * Logs a document's feature info. */ protected static void logSingleHit(IndexSearcher indexSearcher, LTRScoringQuery.ModelWeight modelWeight, int docid, LTRScoringQuery scoringQuery) { final FeatureLogger featureLogger = scoringQuery.getFeatureLogger(); @@ -196,11 +190,35 @@ protected static void logSingleHit(IndexSearcher indexSearcher, LTRScoringQuery. } } + /** + * Scores a single document and returns it. + */ + protected static ScoreDoc scoreSingleHit(int docBase, ScoreDoc hit, int docID, LTRScoringQuery.ModelWeight.ModelScorer scorer) throws IOException { + // Scorer for a LTRScoringQuery.ModelWeight should never be null since we always have to + // call score + // even if no feature scorers match, since a model might use that info to + // return a + // non-zero score. Same applies for the case of advancing a LTRScoringQuery.ModelWeight.ModelScorer + // past the target + // doc since the model algorithm still needs to compute a potentially + // non-zero score from blank features. + assert (scorer != null); + final int targetDoc = docID - docBase; + scorer.docID(); + scorer.iterator().advance(targetDoc); + + scorer.getDocInfo().setOriginalDocScore(hit.score); + hit.score = scorer.score(); + return hit; + } + /** * Scores a single document and returns true if the document's feature info should be logged via the * {@link #logSingleHit(IndexSearcher, org.apache.solr.ltr.LTRScoringQuery.ModelWeight, int, LTRScoringQuery)} * method. Feature info logging is only necessary for the topN documents. + * @deprecated From Solr 9.2.0 onwards this method will be removed. */ + @Deprecated protected static boolean scoreSingleHit(int topN, int docBase, int hitUpto, ScoreDoc hit, int docID, LTRScoringQuery.ModelWeight.ModelScorer scorer, ScoreDoc[] reranked) throws IOException { // Scorer for a LTRScoringQuery.ModelWeight should never be null since we always have to // call score diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java index 799f4d9e36a..17df60270fb 100644 --- a/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java @@ -54,19 +54,16 @@ public LTRInterleavingRescorer( Interleaving interleavingAlgorithm, LTRInterleav } /** - * rescores the documents: + * rescores all the documents: * * @param searcher * current IndexSearcher * @param firstPassTopDocs * documents to rerank; - * @param topN - * documents to return; */ @Override - public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, - int topN) throws IOException { - if ((topN == 0) || (firstPassTopDocs.scoreDocs.length == 0)) { + public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs) throws IOException { + if (firstPassTopDocs.scoreDocs.length == 0) { return firstPassTopDocs; } @@ -75,9 +72,8 @@ public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, firstPassResults = new ScoreDoc[firstPassTopDocs.scoreDocs.length]; System.arraycopy(firstPassTopDocs.scoreDocs, 0, firstPassResults, 0, firstPassTopDocs.scoreDocs.length); } - topN = Math.toIntExact(Math.min(topN, firstPassTopDocs.totalHits.value)); - ScoreDoc[][] reRankedPerModel = rerank(searcher,topN,getFirstPassDocsRanked(firstPassTopDocs)); + ScoreDoc[][] reRankedPerModel = rerank(searcher,getFirstPassDocsRanked(firstPassTopDocs)); if (originalRankingIndex != null) { reRankedPerModel[originalRankingIndex] = firstPassResults; } @@ -91,8 +87,8 @@ public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, return new TopDocs(firstPassTopDocs.totalHits, interleavedResults); } - private ScoreDoc[][] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPassResults) throws IOException { - ScoreDoc[][] reRankedPerModel = new ScoreDoc[rerankingQueries.length][topN]; + private ScoreDoc[][] rerank(IndexSearcher searcher, ScoreDoc[] firstPassResults) throws IOException { + ScoreDoc[][] reRankedPerModel = new ScoreDoc[rerankingQueries.length][firstPassResults.length]; final List<LeafReaderContext> leaves = searcher.getIndexReader().leaves(); LTRScoringQuery.ModelWeight[] modelWeights = new LTRScoringQuery.ModelWeight[rerankingQueries.length]; for (int i = 0; i < rerankingQueries.length; i++) { @@ -101,7 +97,7 @@ private ScoreDoc[][] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPa .createWeight(searcher.rewrite(rerankingQueries[i]), ScoreMode.COMPLETE, 1); } } - scoreFeatures(searcher, topN, modelWeights, firstPassResults, leaves, reRankedPerModel); + scoreFeatures(searcher, modelWeights, firstPassResults, leaves, reRankedPerModel); for (int i = 0; i < rerankingQueries.length; i++) { if (originalRankingIndex == null || originalRankingIndex != i) { @@ -112,8 +108,8 @@ private ScoreDoc[][] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPa return reRankedPerModel; } - public void scoreFeatures(IndexSearcher indexSearcher, - int topN, LTRScoringQuery.ModelWeight[] modelWeights, ScoreDoc[] hits, List<LeafReaderContext> leaves, + private void scoreFeatures(IndexSearcher indexSearcher, + LTRScoringQuery.ModelWeight[] modelWeights, ScoreDoc[] hits, List<LeafReaderContext> leaves, ScoreDoc[][] rerankedPerModel) throws IOException { int readerUpto = -1; @@ -143,9 +139,8 @@ public void scoreFeatures(IndexSearcher indexSearcher, for (int i = 0; i < rerankingQueries.length; i++) { if (modelWeights[i] != null) { final ScoreDoc hit_i = new ScoreDoc(hit.doc, hit.score, hit.shardIndex); - if (scoreSingleHit(topN, docBase, hitUpto, hit_i, docID, scorers[i], rerankedPerModel[i])) { - logSingleHit(indexSearcher, modelWeights[i], hit_i.doc, rerankingQueries[i]); - } + rerankedPerModel[i][hitUpto] = scoreSingleHit(docBase, hit_i, docID, scorers[i]); + logSingleHit(indexSearcher, modelWeights[i], hit_i.doc, rerankingQueries[i]); } } hitUpto++; diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRReRankingPipeline.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRReRankingPipeline.java index e8a69422ae0..c1ee807aeab 100644 --- a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRReRankingPipeline.java +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRReRankingPipeline.java @@ -132,7 +132,8 @@ public void testRescorer() throws Exception { LTRScoringQuery ltrScoringQuery = new LTRScoringQuery(ltrScoringModel); ltrScoringQuery.setRequest(solrQueryRequest); final LTRRescorer rescorer = new LTRRescorer(ltrScoringQuery); - hits = rescorer.rescore(searcher, hits, 2); + assertEquals(2, hits.scoreDocs.length); + hits = rescorer.rescore(searcher, hits); // rerank using the field finalScore assertEquals("1", searcher.doc(hits.scoreDocs[0].doc).get("id")); @@ -183,23 +184,24 @@ public void testDifferentTopN() throws IOException { final LTRRescorer rescorer = new LTRRescorer(scoringQuery); // rerank @ 0 should not change the order - hits = rescorer.rescore(searcher, hits, 0); - assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id")); - assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id")); - assertEquals("2", searcher.doc(hits.scoreDocs[2].doc).get("id")); - assertEquals("3", searcher.doc(hits.scoreDocs[3].doc).get("id")); - assertEquals("4", searcher.doc(hits.scoreDocs[4].doc).get("id")); + { + final TopDocs noHits = new TopDocs(hits.totalHits, new ScoreDoc[0]); + final TopDocs noHitsRescored = rescorer.rescore(searcher, noHits); + assertEquals(0, noHitsRescored.scoreDocs.length); + } // test rerank with different topN cuts + // cap firstPassTopDocs length at topN for (int topN = 1; topN <= 5; topN++) { - log.info("rerank {} documents ", topN); + log.info("rerank {} documents, return {} documents", topN, topN); hits = searcher.search(bqBuilder.build(), 10); final ScoreDoc[] slice = new ScoreDoc[topN]; System.arraycopy(hits.scoreDocs, 0, slice, 0, topN); hits = new TopDocs(hits.totalHits, slice); - hits = rescorer.rescore(searcher, hits, topN); + hits = rescorer.rescore(searcher, hits); + assertEquals(topN, hits.scoreDocs.length); for (int i = topN - 1, j = 0; i >= 0; i--, j++) { if (log.isInfoEnabled()) { log.info("doc {} in pos {}", searcher.doc(hits.scoreDocs[j].doc) @@ -211,6 +213,17 @@ public void testDifferentTopN() throws IOException { } } + + // use full firstPassTopDocs (possibly higher than topN) + for (int topN = 0; topN <= 5; topN++) { + final TopDocs allHits = searcher.search(bqBuilder.build(), 10); + log.info("rerank {} documents, return {} documents", allHits.scoreDocs.length, topN); + + final int topNN = topN; + expectThrows(UnsupportedOperationException.class, () -> { + rescorer.rescore(searcher, allHits, topNN); + }); + } } } diff --git a/solr/core/src/java/org/apache/solr/search/ReRankCollector.java b/solr/core/src/java/org/apache/solr/search/ReRankCollector.java index 11108e74e50..42f05917da1 100644 --- a/solr/core/src/java/org/apache/solr/search/ReRankCollector.java +++ b/solr/core/src/java/org/apache/solr/search/ReRankCollector.java @@ -112,8 +112,12 @@ public TopDocs topDocs(int start, int howMany) { mainDocs.scoreDocs = reRankScoreDocs; - TopDocs rescoredDocs = reRankQueryRescorer - .rescore(searcher, mainDocs, mainDocs.scoreDocs.length); + final TopDocs rescoredDocs; + if (reRankQueryRescorer instanceof ReRankRescorer) { + rescoredDocs = ((ReRankRescorer) reRankQueryRescorer).rescore(searcher, mainDocs); + } else { + rescoredDocs = reRankQueryRescorer.rescore(searcher, mainDocs, mainDocs.scoreDocs.length); + } //Lower howMany to return if we've collected fewer documents. howMany = Math.min(howMany, mainScoreDocs.length); diff --git a/solr/core/src/java/org/apache/solr/search/ReRankRescorer.java b/solr/core/src/java/org/apache/solr/search/ReRankRescorer.java new file mode 100644 index 00000000000..05cf4418dc8 --- /dev/null +++ b/solr/core/src/java/org/apache/solr/search/ReRankRescorer.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.search; + +import java.io.IOException; + +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Rescorer; +import org.apache.lucene.search.TopDocs; + +/** + * This class is a variant of Lucene's {@link Rescorer} class for use by the + * {@link ReRankCollector} to rescore all results in a {@link TopDocs} object + * from an original query. + */ +public abstract class ReRankRescorer extends Rescorer { + + /** + * Rescore an initial first-pass {@link TopDocs}. + * + * @param searcher {@link IndexSearcher} used to produce the first pass topDocs + * @param firstPassTopDocs Hits from the first pass search that are to be rescored. + * It's very important that these hits were produced by the provided searcher; + * otherwise the doc IDs will not match! + */ + public abstract TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs) + throws IOException; + + /** + * Throws an {@link UnsupportedOperationException} exception. + * Use {@link #rescore(IndexSearcher, TopDocs)} instead. + */ + final public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, int topN) + throws IOException + { + throw new UnsupportedOperationException(); + } + +}