diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRReRankingPipeline.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRReRankingPipeline.java index e8a69422ae0..5f34fe84e2b 100644 --- a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRReRankingPipeline.java +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRReRankingPipeline.java @@ -192,14 +192,16 @@ public void testDifferentTopN() throws IOException { // test rerank with different topN cuts + // cap firstPassTopDocs length at topN for (int topN = 1; topN <= 5; topN++) { - log.info("rerank {} documents ", topN); + log.info("rerank {} documents, return {} documents", topN, topN); hits = searcher.search(bqBuilder.build(), 10); final ScoreDoc[] slice = new ScoreDoc[topN]; System.arraycopy(hits.scoreDocs, 0, slice, 0, topN); hits = new TopDocs(hits.totalHits, slice); hits = rescorer.rescore(searcher, hits, topN); + assertEquals(topN, hits.scoreDocs.length); for (int i = topN - 1, j = 0; i >= 0; i--, j++) { if (log.isInfoEnabled()) { log.info("doc {} in pos {}", searcher.doc(hits.scoreDocs[j].doc) @@ -211,6 +213,25 @@ public void testDifferentTopN() throws IOException { } } + + // use full firstPassTopDocs (possibly higher than topN) + for (int topN = 1; topN <= 5; topN++) { + final TopDocs allHits = searcher.search(bqBuilder.build(), 10); + log.info("rerank {} documents, return {} documents", allHits.scoreDocs.length, topN); + + TopDocs rescoredHits = rescorer.rescore(searcher, allHits, topN); + assertEquals(topN, rescoredHits.scoreDocs.length); + for (int i = allHits.scoreDocs.length-1, j = 0; i >= 0 && j < topN; i--, j++) { + if (log.isInfoEnabled()) { + log.info("doc {} in pos {}", searcher.doc(rescoredHits.scoreDocs[j].doc) + .get("id"), j); + } + assertEquals(i, + Integer.parseInt(searcher.doc(rescoredHits.scoreDocs[j].doc).get("id"))); + assertEquals((i + 1) * features.size()*featureWeight, rescoredHits.scoreDocs[j].score, 0.00001); + + } + } } }