diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java index 4909a6f8a96..9ea78081503 100644 --- a/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java @@ -25,13 +25,13 @@ import org.apache.lucene.index.ReaderUtil; import org.apache.lucene.search.Explanation; import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.Rescorer; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.Weight; import org.apache.solr.ltr.interleaving.OriginalRankingLTRScoringQuery; +import org.apache.solr.search.ReRankRescorer; import org.apache.solr.search.SolrIndexSearcher; @@ -41,7 +41,7 @@ * new score to each document. The top documents will be resorted based on the * new score. * */ -public class LTRRescorer extends Rescorer { +public class LTRRescorer extends ReRankRescorer { final private LTRScoringQuery scoringQuery; @@ -118,7 +118,11 @@ protected static void heapify(ScoreDoc[] hits, int size) { * documents to rerank; * @param topN * documents to return; + + * @deprecated Use {@link #rescore(IndexSearcher, TopDocs)} instead. + * From Solr 9.1.0 onwards this method will be removed. 
*/ + @Deprecated @Override public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, int topN) throws IOException { @@ -133,6 +137,31 @@ public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, return new TopDocs(firstPassTopDocs.totalHits, reranked); } + /** + * rescores all the documents: + * + * @param searcher + * current IndexSearcher + * @param firstPassTopDocs + * documents to rerank; + */ + @Override + public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs) throws IOException { + if (firstPassTopDocs.scoreDocs.length == 0) { + return firstPassTopDocs; + } + final ScoreDoc[] firstPassResults = getFirstPassDocsRanked(firstPassTopDocs); + + final ScoreDoc[] reranked = rerank(searcher, firstPassResults); + + return new TopDocs(firstPassTopDocs.totalHits, reranked); + } + + /** + * @deprecated Use {@link #rerank(IndexSearcher, ScoreDoc[])} instead. + * From Solr 9.1.0 onwards this method will be removed. + */ + @Deprecated private ScoreDoc[] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPassResults) throws IOException { final ScoreDoc[] reranked = new ScoreDoc[topN]; final List leaves = searcher.getIndexReader().leaves(); @@ -145,6 +174,18 @@ private ScoreDoc[] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPass return reranked; } + private ScoreDoc[] rerank(IndexSearcher searcher, ScoreDoc[] firstPassResults) throws IOException { + final ScoreDoc[] reranked = new ScoreDoc[firstPassResults.length]; + final List leaves = searcher.getIndexReader().leaves(); + final LTRScoringQuery.ModelWeight modelWeight = (LTRScoringQuery.ModelWeight) searcher + .createWeight(searcher.rewrite(scoringQuery), ScoreMode.COMPLETE, 1); + + scoreFeatures(searcher, modelWeight, firstPassResults, leaves, reranked); + // Must sort all documents that we reranked; every one of them is returned + Arrays.sort(reranked, scoreComparator); + return reranked; + } + protected static ScoreDoc[] getFirstPassDocsRanked(TopDocs
firstPassTopDocs) { final ScoreDoc[] hits = firstPassTopDocs.scoreDocs; Arrays.sort(hits, docComparator); @@ -153,6 +194,10 @@ protected static ScoreDoc[] getFirstPassDocsRanked(TopDocs firstPassTopDocs) { return hits; } + /** + * @deprecated From Solr 9.1.0 onwards this method will be removed. + */ + @Deprecated public void scoreFeatures(IndexSearcher indexSearcher, int topN, LTRScoringQuery.ModelWeight modelWeight, ScoreDoc[] hits, List leaves, ScoreDoc[] reranked) throws IOException { @@ -185,9 +230,39 @@ public void scoreFeatures(IndexSearcher indexSearcher, } } + private void scoreFeatures(IndexSearcher indexSearcher, + LTRScoringQuery.ModelWeight modelWeight, ScoreDoc[] hits, List leaves, + ScoreDoc[] reranked) throws IOException { + + int readerUpto = -1; + int endDoc = 0; + int docBase = 0; + + LTRScoringQuery.ModelWeight.ModelScorer scorer = null; + int hitUpto = 0; + + while (hitUpto < hits.length) { + final ScoreDoc hit = hits[hitUpto]; + final int docID = hit.doc; + LeafReaderContext readerContext = null; + while (docID >= endDoc) { + readerUpto++; + readerContext = leaves.get(readerUpto); + endDoc = readerContext.docBase + readerContext.reader().maxDoc(); + } + // We advanced to another segment + if (readerContext != null) { + docBase = readerContext.docBase; + scorer = modelWeight.scorer(readerContext); + } + reranked[hitUpto] = scoreSingleHit(docBase, hit, docID, scorer); + logSingleHit(indexSearcher, modelWeight, hit.doc, scoringQuery); + hitUpto++; + } + } + /** - * Call this method if the {@link #scoreSingleHit(int, int, int, ScoreDoc, int, org.apache.solr.ltr.LTRScoringQuery.ModelWeight.ModelScorer, ScoreDoc[])} - * method indicated that the document's feature info should be logged. + * Logs a document's feature info. 
*/ protected static void logSingleHit(IndexSearcher indexSearcher, LTRScoringQuery.ModelWeight modelWeight, int docid, LTRScoringQuery scoringQuery) { final FeatureLogger featureLogger = scoringQuery.getFeatureLogger(); @@ -200,7 +275,9 @@ protected static void logSingleHit(IndexSearcher indexSearcher, LTRScoringQuery. * Scores a single document and returns true if the document's feature info should be logged via the * {@link #logSingleHit(IndexSearcher, org.apache.solr.ltr.LTRScoringQuery.ModelWeight, int, LTRScoringQuery)} * method. Feature info logging is only necessary for the topN documents. + * @deprecated From Solr 9.1.0 onwards this method will be removed. */ + @Deprecated protected static boolean scoreSingleHit(int topN, int docBase, int hitUpto, ScoreDoc hit, int docID, LTRScoringQuery.ModelWeight.ModelScorer scorer, ScoreDoc[] reranked) throws IOException { // Scorer for a LTRScoringQuery.ModelWeight should never be null since we always have to // call score @@ -243,6 +320,28 @@ protected static boolean scoreSingleHit(int topN, int docBase, int hitUpto, Scor return logHit; } + /** + * Scores a single document and returns it. + */ + protected static ScoreDoc scoreSingleHit(int docBase, ScoreDoc hit, int docID, LTRScoringQuery.ModelWeight.ModelScorer scorer) throws IOException { + // Scorer for a LTRScoringQuery.ModelWeight should never be null since we always have to + // call score + // even if no feature scorers match, since a model might use that info to + // return a + // non-zero score. Same applies for the case of advancing a LTRScoringQuery.ModelWeight.ModelScorer + // past the target + // doc since the model algorithm still needs to compute a potentially + // non-zero score from blank features. 
+ assert (scorer != null); + final int targetDoc = docID - docBase; + scorer.docID(); + scorer.iterator().advance(targetDoc); + + scorer.getDocInfo().setOriginalDocScore(hit.score); + hit.score = scorer.score(); + return hit; + } + @Override public Explanation explain(IndexSearcher searcher, Explanation firstPassExplanation, int docID) throws IOException { diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java index 799f4d9e36a..9dd3bcd1101 100644 --- a/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java @@ -62,7 +62,11 @@ public LTRInterleavingRescorer( Interleaving interleavingAlgorithm, LTRInterleav * documents to rerank; * @param topN * documents to return; + + * @deprecated Use {@link #rescore(IndexSearcher, TopDocs)} instead. + * From Solr 9.1.0 onwards this method will be removed. 
*/ + @Deprecated @Override public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, int topN) throws IOException { @@ -91,6 +95,45 @@ public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, return new TopDocs(firstPassTopDocs.totalHits, interleavedResults); } + /** + * rescores all the documents: + * + * @param searcher + * current IndexSearcher + * @param firstPassTopDocs + * documents to rerank; + */ + @Override + public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs) throws IOException { + if (firstPassTopDocs.scoreDocs.length == 0) { + return firstPassTopDocs; + } + + ScoreDoc[] firstPassResults = null; + if(originalRankingIndex != null) { + firstPassResults = new ScoreDoc[firstPassTopDocs.scoreDocs.length]; + System.arraycopy(firstPassTopDocs.scoreDocs, 0, firstPassResults, 0, firstPassTopDocs.scoreDocs.length); + } + + ScoreDoc[][] reRankedPerModel = rerank(searcher,getFirstPassDocsRanked(firstPassTopDocs)); + if (originalRankingIndex != null) { + reRankedPerModel[originalRankingIndex] = firstPassResults; + } + InterleavingResult interleaved = interleavingAlgorithm.interleave(reRankedPerModel[0], reRankedPerModel[1]); + ScoreDoc[] interleavedResults = interleaved.getInterleavedResults(); + + ArrayList> interleavingPicks = interleaved.getInterleavingPicks(); + rerankingQueries[0].setPickedInterleavingDocIds(interleavingPicks.get(0)); + rerankingQueries[1].setPickedInterleavingDocIds(interleavingPicks.get(1)); + + return new TopDocs(firstPassTopDocs.totalHits, interleavedResults); + } + + /** + * @deprecated Use {@link #rerank(IndexSearcher, ScoreDoc[])} instead. + * From Solr 9.1.0 onwards this method will be removed. 
+ */ + @Deprecated private ScoreDoc[][] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPassResults) throws IOException { ScoreDoc[][] reRankedPerModel = new ScoreDoc[rerankingQueries.length][topN]; final List leaves = searcher.getIndexReader().leaves(); @@ -112,6 +155,31 @@ private ScoreDoc[][] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPa return reRankedPerModel; } + private ScoreDoc[][] rerank(IndexSearcher searcher, ScoreDoc[] firstPassResults) throws IOException { + ScoreDoc[][] reRankedPerModel = new ScoreDoc[rerankingQueries.length][firstPassResults.length]; + final List leaves = searcher.getIndexReader().leaves(); + LTRScoringQuery.ModelWeight[] modelWeights = new LTRScoringQuery.ModelWeight[rerankingQueries.length]; + for (int i = 0; i < rerankingQueries.length; i++) { + if (originalRankingIndex == null || originalRankingIndex != i) { + modelWeights[i] = (LTRScoringQuery.ModelWeight) searcher + .createWeight(searcher.rewrite(rerankingQueries[i]), ScoreMode.COMPLETE, 1); + } + } + scoreFeatures(searcher, modelWeights, firstPassResults, leaves, reRankedPerModel); + + for (int i = 0; i < rerankingQueries.length; i++) { + if (originalRankingIndex == null || originalRankingIndex != i) { + Arrays.sort(reRankedPerModel[i], scoreComparator); + } + } + + return reRankedPerModel; + } + + /** + * @deprecated From Solr 9.1.0 onwards this method will be removed. 
+ */ + @Deprecated public void scoreFeatures(IndexSearcher indexSearcher, int topN, LTRScoringQuery.ModelWeight[] modelWeights, ScoreDoc[] hits, List leaves, ScoreDoc[][] rerankedPerModel) throws IOException { @@ -153,6 +221,46 @@ public void scoreFeatures(IndexSearcher indexSearcher, } + private void scoreFeatures(IndexSearcher indexSearcher, + LTRScoringQuery.ModelWeight[] modelWeights, ScoreDoc[] hits, List leaves, + ScoreDoc[][] rerankedPerModel) throws IOException { + + int readerUpto = -1; + int endDoc = 0; + int docBase = 0; + int hitUpto = 0; + LTRScoringQuery.ModelWeight.ModelScorer[] scorers = new LTRScoringQuery.ModelWeight.ModelScorer[rerankingQueries.length]; + while (hitUpto < hits.length) { + final ScoreDoc hit = hits[hitUpto]; + final int docID = hit.doc; + LeafReaderContext readerContext = null; + while (docID >= endDoc) { + readerUpto++; + readerContext = leaves.get(readerUpto); + endDoc = readerContext.docBase + readerContext.reader().maxDoc(); + } + + // We advanced to another segment + if (readerContext != null) { + docBase = readerContext.docBase; + for (int i = 0; i < modelWeights.length; i++) { + if (modelWeights[i] != null) { + scorers[i] = modelWeights[i].scorer(readerContext); + } + } + } + for (int i = 0; i < rerankingQueries.length; i++) { + if (modelWeights[i] != null) { + final ScoreDoc hit_i = new ScoreDoc(hit.doc, hit.score, hit.shardIndex); + rerankedPerModel[i][hitUpto] = scoreSingleHit(docBase, hit_i, docID, scorers[i]); + logSingleHit(indexSearcher, modelWeights[i], hit_i.doc, rerankingQueries[i]); + } + } + hitUpto++; + } + + } + @Override public Explanation explain(IndexSearcher searcher, Explanation firstPassExplanation, int docID) throws IOException { diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRReRankingPipeline.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRReRankingPipeline.java index e8a69422ae0..e1d6f194568 100644 --- 
a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRReRankingPipeline.java +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRReRankingPipeline.java @@ -132,7 +132,7 @@ public void testRescorer() throws Exception { LTRScoringQuery ltrScoringQuery = new LTRScoringQuery(ltrScoringModel); ltrScoringQuery.setRequest(solrQueryRequest); final LTRRescorer rescorer = new LTRRescorer(ltrScoringQuery); - hits = rescorer.rescore(searcher, hits, 2); + hits = random().nextBoolean() ? rescorer.rescore(searcher, hits) : rescorer.rescore(searcher, hits, 2); // rerank using the field finalScore assertEquals("1", searcher.doc(hits.scoreDocs[0].doc).get("id")); @@ -189,17 +189,24 @@ public void testDifferentTopN() throws IOException { assertEquals("2", searcher.doc(hits.scoreDocs[2].doc).get("id")); assertEquals("3", searcher.doc(hits.scoreDocs[3].doc).get("id")); assertEquals("4", searcher.doc(hits.scoreDocs[4].doc).get("id")); + { + final TopDocs noHits = new TopDocs(hits.totalHits, new ScoreDoc[0]); + final TopDocs noHitsRescored = rescorer.rescore(searcher, noHits); + assertEquals(0, noHitsRescored.scoreDocs.length); + } // test rerank with different topN cuts + // cap firstPassTopDocs length at topN for (int topN = 1; topN <= 5; topN++) { - log.info("rerank {} documents ", topN); + log.info("rerank {} documents, return {} documents", topN, topN); hits = searcher.search(bqBuilder.build(), 10); final ScoreDoc[] slice = new ScoreDoc[topN]; System.arraycopy(hits.scoreDocs, 0, slice, 0, topN); hits = new TopDocs(hits.totalHits, slice); - hits = rescorer.rescore(searcher, hits, topN); + hits = random().nextBoolean() ? 
rescorer.rescore(searcher, hits) : rescorer.rescore(searcher, hits, topN); + assertEquals(topN, hits.scoreDocs.length); for (int i = topN - 1, j = 0; i >= 0; i--, j++) { if (log.isInfoEnabled()) { log.info("doc {} in pos {}", searcher.doc(hits.scoreDocs[j].doc) @@ -211,6 +218,25 @@ public void testDifferentTopN() throws IOException { } } + + // use full firstPassTopDocs (possibly higher than topN) + for (int topN = 1; topN <= 5; topN++) { + final TopDocs allHits = searcher.search(bqBuilder.build(), 10); + log.info("rerank {} documents, return {} documents", allHits.scoreDocs.length, topN); + + TopDocs rescoredHits = rescorer.rescore(searcher, allHits, topN); + assertEquals(topN, rescoredHits.scoreDocs.length); + for (int i = allHits.scoreDocs.length-1, j = 0; i >= 0 && j < topN; i--, j++) { + if (log.isInfoEnabled()) { + log.info("doc {} in pos {}", searcher.doc(rescoredHits.scoreDocs[j].doc) + .get("id"), j); + } + assertEquals(i, + Integer.parseInt(searcher.doc(rescoredHits.scoreDocs[j].doc).get("id"))); + assertEquals((i + 1) * features.size()*featureWeight, rescoredHits.scoreDocs[j].score, 0.00001); + + } + } } } diff --git a/solr/core/src/java/org/apache/solr/search/ReRankCollector.java b/solr/core/src/java/org/apache/solr/search/ReRankCollector.java index 11108e74e50..42f05917da1 100644 --- a/solr/core/src/java/org/apache/solr/search/ReRankCollector.java +++ b/solr/core/src/java/org/apache/solr/search/ReRankCollector.java @@ -112,8 +112,12 @@ public TopDocs topDocs(int start, int howMany) { mainDocs.scoreDocs = reRankScoreDocs; - TopDocs rescoredDocs = reRankQueryRescorer - .rescore(searcher, mainDocs, mainDocs.scoreDocs.length); + final TopDocs rescoredDocs; + if (reRankQueryRescorer instanceof ReRankRescorer) { + rescoredDocs = ((ReRankRescorer) reRankQueryRescorer).rescore(searcher, mainDocs); + } else { + rescoredDocs = reRankQueryRescorer.rescore(searcher, mainDocs, mainDocs.scoreDocs.length); + } //Lower howMany to return if we've collected fewer 
documents. howMany = Math.min(howMany, mainScoreDocs.length); diff --git a/solr/core/src/java/org/apache/solr/search/ReRankRescorer.java b/solr/core/src/java/org/apache/solr/search/ReRankRescorer.java new file mode 100644 index 00000000000..46876b68df0 --- /dev/null +++ b/solr/core/src/java/org/apache/solr/search/ReRankRescorer.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.search; + +import java.io.IOException; + +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Rescorer; +import org.apache.lucene.search.TopDocs; + +/** + * This class is a variant of Lucene's {@link Rescorer} class for use by the + * {@link ReRankCollector} to rescore all results in a {@link TopDocs} object + * from an original query. + */ +public abstract class ReRankRescorer extends Rescorer { + + /** + * Rescore an initial first-pass {@link TopDocs}. + * + * @param searcher {@link IndexSearcher} used to produce the first pass topDocs + * @param firstPassTopDocs Hits from the first pass search that are to be rescored. + * It's very important that these hits were produced by the provided searcher; + * otherwise the doc IDs will not match! 
+ */ + public abstract TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs) + throws IOException; + + /** + * Throws an {@link UnsupportedOperationException}. + * @deprecated Use {@link #rescore(IndexSearcher, TopDocs)} instead. + * From Solr 9.1.0 onwards this method will be final. + */ + @Deprecated + public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, int topN) + throws IOException + { + throw new UnsupportedOperationException(); + } + +}