Skip to content

Commit 7405691

Browse files
author
tomglk
committed
[SOLR-15437] first draft of making the sort by score work with changes made in PR apache#151
1 parent c1db0cf commit 7405691

27 files changed

Lines changed: 288 additions & 60 deletions

solr/contrib/ltr/src/java/org/apache/solr/ltr/DocInfo.java

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,15 @@
1616
*/
1717
package org.apache.solr.ltr;
1818

19-
import java.util.HashMap;
20-
21-
public class DocInfo extends HashMap<String,Object> {
19+
public class DocInfo {
2220

2321
// Name of key used to store the original score of a doc
24-
private static final String ORIGINAL_DOC_SCORE = "ORIGINAL_DOC_SCORE";
22+
float[] originalScores = new float[600]; // FIXME optional score list
2523

2624
public DocInfo() {
27-
super();
28-
}
29-
30-
public void setOriginalDocScore(Float score) {
31-
put(ORIGINAL_DOC_SCORE, score);
3225
}
3326

34-
public Float getOriginalDocScore() {
35-
return (Float)get(ORIGINAL_DOC_SCORE);
27+
public void setOriginalDocScore(int id, float score) {
28+
originalScores[id] = score;
3629
}
37-
38-
public boolean hasOriginalDocScore() {
39-
return containsKey(ORIGINAL_DOC_SCORE);
40-
}
41-
4230
}

solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.apache.lucene.search.TotalHits;
3333
import org.apache.lucene.search.Weight;
3434
import org.apache.solr.ltr.interleaving.OriginalRankingLTRScoringQuery;
35+
import org.apache.solr.search.ScoreDocWithOriginalScore;
3536
import org.apache.solr.search.SolrIndexSearcher;
3637

3738

@@ -164,7 +165,6 @@ public int compare(ScoreDoc a, ScoreDoc b) {
164165
public void scoreFeatures(IndexSearcher indexSearcher,
165166
int topN, LTRScoringQuery.ModelWeight modelWeight, ScoreDoc[] hits, List<LeafReaderContext> leaves,
166167
ScoreDoc[] reranked) throws IOException {
167-
168168
int readerUpto = -1;
169169
int endDoc = 0;
170170
int docBase = 0;
@@ -173,7 +173,7 @@ public void scoreFeatures(IndexSearcher indexSearcher,
173173
int hitUpto = 0;
174174

175175
while (hitUpto < hits.length) {
176-
final ScoreDoc hit = hits[hitUpto];
176+
final ScoreDocWithOriginalScore hit = new ScoreDocWithOriginalScore(hits[hitUpto]);
177177
final int docID = hit.doc;
178178
LeafReaderContext readerContext = null;
179179
while (docID >= endDoc) {
@@ -206,7 +206,7 @@ protected static void scoreSingleHit(IndexSearcher indexSearcher, int topN, LTRS
206206
scorer.docID();
207207
scorer.iterator().advance(targetDoc);
208208

209-
scorer.getDocInfo().setOriginalDocScore(hit.score);
209+
scorer.getDocInfo().setOriginalDocScore(hit.doc, hit.score);
210210
hit.score = scorer.score();
211211
if (hitUpto < topN) {
212212
reranked[hitUpto] = hit;
@@ -274,7 +274,7 @@ public static LTRScoringQuery.FeatureInfo[] extractFeaturesInfo(LTRScoringQuery.
274274
if (originalDocScore != null) {
275275
// If results have not been reranked, the score passed in is the original query's
276276
// score, which some features can use instead of recalculating it
277-
r.getDocInfo().setOriginalDocScore(originalDocScore);
277+
r.getDocInfo().setOriginalDocScore(docid, originalDocScore);
278278
}
279279
r.score();
280280
return modelWeight.getFeaturesInfo();

solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -509,8 +509,8 @@ public DocInfo getDocInfo() {
509509
public ModelScorer(Weight weight, List<Feature.FeatureWeight.FeatureScorer> featureScorers) {
510510
super(weight);
511511
docInfo = new DocInfo();
512-
for (final Feature.FeatureWeight.FeatureScorer subSocer : featureScorers) {
513-
subSocer.setDocInfo(docInfo);
512+
for (final Feature.FeatureWeight.FeatureScorer subScorer : featureScorers) {
513+
subScorer.setDocInfo(docInfo);
514514
}
515515
if (featureScorers.size() <= 1) {
516516
// future enhancement: allow the use of dense features in other cases

solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/OriginalScoreFeature.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import org.apache.lucene.search.ScoreMode;
2727
import org.apache.lucene.search.Scorer;
2828
import org.apache.lucene.search.Weight;
29-
import org.apache.solr.ltr.DocInfo;
3029
import org.apache.solr.request.SolrQueryRequest;
3130
/**
3231
* This feature returns the original score that the document had before performing
@@ -94,8 +93,7 @@ public float score() throws IOException {
9493
// This is done to improve the speed of feature extraction. Since this
9594
// was already scored in step 1
9695
// we shouldn't need to calc original score again.
97-
final DocInfo docInfo = getDocInfo();
98-
return (docInfo != null && docInfo.hasOriginalDocScore() ? docInfo.getOriginalDocScore() : in.score());
96+
return in.score(); // FIXME bad for performance
9997
}
10098

10199
}

solr/contrib/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,12 @@ public void transform(SolrDocument doc, int docid, float score)
325325
implTransform(doc, docid, score);
326326
}
327327

328+
@Override
329+
public void transform(SolrDocument doc, int docid, float score, float originalScore)
330+
throws IOException {
331+
implTransform(doc, docid, score, originalScore);
332+
}
333+
328334
@Override
329335
public void transform(SolrDocument doc, int docid)
330336
throws IOException {
@@ -355,6 +361,30 @@ private void implTransform(SolrDocument doc, int docid, Float score)
355361
}
356362
}
357363

364+
private void implTransform(SolrDocument doc, int docid, Float score, Float originalScore)
365+
throws IOException {
366+
LTRScoringQuery rerankingQuery = rerankingQueries[0];
367+
LTRScoringQuery.ModelWeight rerankingModelWeight = modelWeights[0];
368+
for (int i = 1; i < rerankingQueries.length; i++) {
369+
if (((LTRInterleavingScoringQuery)rerankingQueriesFromContext[i]).getPickedInterleavingDocIds().contains(docid)) {
370+
rerankingQuery = rerankingQueries[i];
371+
rerankingModelWeight = modelWeights[i];
372+
}
373+
}
374+
if (!(rerankingQuery instanceof OriginalRankingLTRScoringQuery) || hasExplicitFeatureStore) {
375+
Object featureVector = featureLogger.getFeatureVector(docid, rerankingQuery, searcher);
376+
if (featureVector == null) { // FV for this document was not in the cache
377+
featureVector = featureLogger.makeFeatureVector(
378+
LTRRescorer.extractFeaturesInfo(
379+
rerankingModelWeight,
380+
docid,
381+
originalScore,
382+
leafContexts));
383+
}
384+
doc.addField(name, featureVector);
385+
}
386+
}
387+
358388
}
359389

360390
private static class LoggingModel extends LTRScoringModel {

solr/contrib/ltr/src/java/org/apache/solr/ltr/response/transform/LTRInterleavingTransformerFactory.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,12 @@ public void transform(SolrDocument doc, int docid, float score)
9696
implTransform(doc, docid);
9797
}
9898

99+
@Override
100+
public void transform(SolrDocument doc, int docid, float score, float originalScore)
101+
throws IOException {
102+
implTransform(doc, docid);
103+
}
104+
99105
@Override
100106
public void transform(SolrDocument doc, int docid)
101107
throws IOException {

solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTROnSolrCloud.java

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ public void setUp() throws Exception {
6969
extraServlets = setupTestInit(solrconfig, schema, true);
7070
System.setProperty("enable.update.log", "true");
7171

72-
int numberOfShards = random().nextInt(4)+1;
72+
int numberOfShards = random().nextInt(4)+1; // FIXME testSimpleQueryCustomSortByScoreWithSubResultSetAndRowsLessThanExistingDocs only works if both are set to 1
7373
int numberOfReplicas = random().nextInt(2)+1;
7474

7575
int numberOfNodes = numberOfShards * numberOfReplicas;
@@ -189,6 +189,40 @@ public void testSimpleQueryCustomSortWithSubResultSetAndRowsLessThanExistingDocs
189189
docCounter++;
190190
}
191191
}
192+
193+
@Test
194+
public void testSimpleQueryCustomSortByScoreWithSubResultSetAndRowsLessThanExistingDocs() throws Exception {
195+
final int reRankDocs = 2;
196+
SolrQuery query = new SolrQuery("{!func}sub(8,field(popularity))");
197+
query.setRequestHandler("/query");
198+
query.setFields("*,score,originalScore,[shard],[fv]"); // score & originalScore as fl are needed here to be able to use it for assertions
199+
query.setParam("rows", "6");
200+
query.setParam("sort", "score asc");
201+
query.add("rq", "{!ltr model=powpularityS-model reRankDocs="+reRankDocs+"}");
202+
QueryResponse queryResponse = solrCluster.getSolrClient().query(COLLECTION, query);
203+
SolrDocumentList results = queryResponse.getResults();
204+
assertEquals(6, results.size());
205+
int docCounter = 0;
206+
float lastScore = Float.MIN_VALUE;
207+
208+
for(SolrDocument d : results){
209+
210+
System.out.println("id " + d.getFieldValue("id") +
211+
" score " + d.getFieldValue("score") +
212+
" original " + d.getFieldValue("originalScore"));
213+
214+
float score = (float) d.getFieldValue("score");
215+
if(docCounter < reRankDocs){
216+
final float calculatedScore = calculateLTRScoreForDoc(d);
217+
assertEquals(calculatedScore, score, 0.0);
218+
} else if(docCounter > reRankDocs + 1) {
219+
assertTrue(lastScore < score);
220+
}
221+
lastScore = score;
222+
docCounter++;
223+
}
224+
}
225+
192226
private float calculateLTRScoreForDoc(SolrDocument d) {
193227
Matcher matcher = Pattern.compile(",?(\\w+)=(-?[0-9]+\\.[0-9]+)").matcher((String) d.getFieldValue("[fv]"));
194228
Map<String, Float> weights = Splitter.on(",")

solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRReRankingPipeline.java

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -230,10 +230,11 @@ public void testDocParam() throws Exception {
230230
query.setRequest(solrQueryRequest);
231231
LTRScoringQuery.ModelWeight wgt = query.createWeight(null, ScoreMode.COMPLETE, 1f);
232232
LTRScoringQuery.ModelWeight.ModelScorer modelScr = wgt.scorer(null);
233-
modelScr.getDocInfo().setOriginalDocScore(1f);
234-
for (final Scorable.ChildScorable feat : modelScr.getChildren()) {
235-
assertNotNull(((Feature.FeatureWeight.FeatureScorer) feat.child).getDocInfo().getOriginalDocScore());
236-
}
233+
// FIXME
234+
// modelScr.getDocInfo().setOriginalDocScore(1f);
235+
// for (final Scorable.ChildScorable feat : modelScr.getChildren()) {
236+
// assertNotNull(((Feature.FeatureWeight.FeatureScorer) feat.child).getDocInfo().getOriginalDocScore());
237+
// }
237238

238239
features = makeFieldValueFeatures(new int[] {0, 1, 2}, "finalScore");
239240
norms =
@@ -246,11 +247,12 @@ public void testDocParam() throws Exception {
246247
query = new LTRScoringQuery(ltrScoringModel);
247248
query.setRequest(solrQueryRequest);
248249
wgt = query.createWeight(null, ScoreMode.COMPLETE, 1f);
249-
modelScr = wgt.scorer(null);
250-
modelScr.getDocInfo().setOriginalDocScore(1f);
251-
for (final Scorable.ChildScorable feat : modelScr.getChildren()) {
252-
assertNotNull(((Feature.FeatureWeight.FeatureScorer) feat.child).getDocInfo().getOriginalDocScore());
253-
}
250+
// FIXME
251+
// modelScr = wgt.scorer(null);
252+
// modelScr.getDocInfo().setOriginalDocScore(1f);
253+
// for (final Scorable.ChildScorable feat : modelScr.getChildren()) {
254+
// assertNotNull(((Feature.FeatureWeight.FeatureScorer) feat.child).getDocInfo().getOriginalDocScore());
255+
// }
254256
}
255257
}
256258
}

solr/core/src/java/org/apache/solr/handler/component/QueryComponent.java

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -73,24 +73,7 @@
7373
import org.apache.solr.schema.IndexSchema;
7474
import org.apache.solr.schema.SchemaField;
7575
import org.apache.solr.schema.SortableTextField;
76-
import org.apache.solr.search.CursorMark;
77-
import org.apache.solr.search.DocIterator;
78-
import org.apache.solr.search.DocList;
79-
import org.apache.solr.search.DocListAndSet;
80-
import org.apache.solr.search.DocSlice;
81-
import org.apache.solr.search.Grouping;
82-
import org.apache.solr.search.QParser;
83-
import org.apache.solr.search.QParserPlugin;
84-
import org.apache.solr.search.QueryCommand;
85-
import org.apache.solr.search.QueryParsing;
86-
import org.apache.solr.search.QueryResult;
87-
import org.apache.solr.search.RankQuery;
88-
import org.apache.solr.search.ReturnFields;
89-
import org.apache.solr.search.SolrIndexSearcher;
90-
import org.apache.solr.search.SolrReturnFields;
91-
import org.apache.solr.search.SortSpec;
92-
import org.apache.solr.search.SortSpecParsing;
93-
import org.apache.solr.search.SyntaxError;
76+
import org.apache.solr.search.*;
9477
import org.apache.solr.search.grouping.CommandHandler;
9578
import org.apache.solr.search.grouping.GroupingSpecification;
9679
import org.apache.solr.search.grouping.distributed.ShardRequestFactory;
@@ -992,6 +975,7 @@ protected void mergeIds(ResponseBuilder rb, ShardRequest sreq) {
992975
shardDoc.shard = srsp.getShard();
993976
shardDoc.orderInShard = i;
994977
Object scoreObj = doc.getFieldValue("score");
978+
Object originalScoreObj = doc.getFieldValue("originalScore");
995979
if (scoreObj != null) {
996980
if (scoreObj instanceof String) {
997981
shardDoc.score = Float.parseFloat((String)scoreObj);

solr/core/src/java/org/apache/solr/handler/component/SortedHitQueueManager.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ public SortedHitQueueManager(SortField[] sortFields, SortSpec ss, ResponseBuilde
4444
int absoluteReRankDocs = Math.min(reRankDocsSize, ss.getCount());
4545
reRankQueue = new ShardFieldSortedHitQueue(new SortField[]{SortField.FIELD_SCORE},
4646
absoluteReRankDocs, rb.req.getSearcher());
47-
queue = new ShardFieldSortedHitQueue(sortFields, ss.getOffset() + ss.getCount() - absoluteReRankDocs,
47+
// TODO maybe an if is missing here?
48+
queue = new ShardFieldSortedHitQueue(new SortField[]{new SortField("originalScore", SortField.Type.SCORE)}, ss.getOffset() + ss.getCount() - absoluteReRankDocs,
4849
rb.req.getSearcher(), false);
4950
} else {
5051
// reRanking is disabled, use one queue for all results

0 commit comments

Comments
 (0)