SOLR-15437: ReRanking / LTR does not work in SolrCloud with custom sort by score #171

Draft · wants to merge 13 commits into main
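For context, the broken scenario is an LTR rerank (rq={!ltr ...}) combined with a custom sort in SolrCloud. Below is a minimal SolrJ sketch of that request shape; the ZooKeeper address and collection name are placeholders, and the model name follows the tests further down.

import java.util.Collections;
import java.util.Optional;

import org.apache.solr.client.solrj.SolrQuery;
import org.apache.solr.client.solrj.impl.CloudSolrClient;
import org.apache.solr.client.solrj.response.QueryResponse;

public class LtrCustomSortRepro {
  public static void main(String[] args) throws Exception {
    // Placeholder ZooKeeper address; point this at a real SolrCloud ensemble.
    try (CloudSolrClient client = new CloudSolrClient.Builder(
        Collections.singletonList("localhost:9983"), Optional.empty()).build()) {
      SolrQuery query = new SolrQuery("*:*");
      query.setFields("*", "score", "[fv]");
      // Custom sort instead of the default "score desc"; this is the case
      // that SOLR-15437 reports as broken in SolrCloud.
      query.setParam("sort", "id asc");
      // Rerank the top 8 hits with the LTR model loaded by the tests below.
      query.add("rq", "{!ltr model=powpularityS-model reRankDocs=8}");
      QueryResponse rsp = client.query("collection1", query);
      rsp.getResults().forEach(d -> System.out.println(d.get("id") + " -> " + d.get("score")));
    }
  }
}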
20 changes: 4 additions & 16 deletions solr/contrib/ltr/src/java/org/apache/solr/ltr/DocInfo.java
@@ -16,27 +16,15 @@
 */
package org.apache.solr.ltr;

-import java.util.HashMap;
-
-public class DocInfo extends HashMap<String,Object> {
+public class DocInfo {

-  // Name of key used to store the original score of a doc
-  private static final String ORIGINAL_DOC_SCORE = "ORIGINAL_DOC_SCORE";
+  float[] originalScores = new float[600]; // FIXME optional score list

-  public DocInfo() {
-    super();
-  }
-
-  public void setOriginalDocScore(Float score) {
-    put(ORIGINAL_DOC_SCORE, score);
-  }
-
-  public Float getOriginalDocScore() {
-    return (Float)get(ORIGINAL_DOC_SCORE);
-  }
+  public void setOriginalDocScore(int id, float score) {
+    originalScores[id] = score;
+  }

-  public boolean hasOriginalDocScore() {
-    return containsKey(ORIGINAL_DOC_SCORE);
-  }
-
}
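The hunk above swaps the HashMap-backed DocInfo (a single ORIGINAL_DOC_SCORE entry) for an array of original scores indexed by docid, so the pre-rerank score of every candidate document can be kept at once. A short sketch of the intended usage under this draft API; note the fixed float[600] that the FIXME flags as a placeholder:

import org.apache.solr.ltr.DocInfo;

public class DocInfoSketch {
  public static void main(String[] args) {
    DocInfo docInfo = new DocInfo();
    // During rescoring, each hit's pre-rerank score is remembered under its
    // docid (see LTRRescorer.scoreSingleHit below):
    docInfo.setOriginalDocScore(42, 3.14f);
    // Caveat of the draft: a docid >= 600 would overflow the fixed-size array
    // with an ArrayIndexOutOfBoundsException, which is what the FIXME is about.
  }
}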
solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java

@@ -32,6 +32,7 @@
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.Weight;
import org.apache.solr.ltr.interleaving.OriginalRankingLTRScoringQuery;
+import org.apache.solr.search.ScoreDocWithOriginalScore;
import org.apache.solr.search.SolrIndexSearcher;
@@ -164,7 +165,6 @@ public int compare(ScoreDoc a, ScoreDoc b) {
public void scoreFeatures(IndexSearcher indexSearcher,
    int topN, LTRScoringQuery.ModelWeight modelWeight, ScoreDoc[] hits, List<LeafReaderContext> leaves,
    ScoreDoc[] reranked) throws IOException {
-
  int readerUpto = -1;
  int endDoc = 0;
  int docBase = 0;
@@ -173,7 +173,7 @@ public void scoreFeatures(IndexSearcher indexSearcher,
  int hitUpto = 0;

  while (hitUpto < hits.length) {
-    final ScoreDoc hit = hits[hitUpto];
+    final ScoreDocWithOriginalScore hit = new ScoreDocWithOriginalScore(hits[hitUpto]);
    final int docID = hit.doc;
    LeafReaderContext readerContext = null;
    while (docID >= endDoc) {
@@ -206,7 +206,7 @@ protected static void scoreSingleHit(IndexSearcher indexSearcher, int topN, LTRS
  scorer.docID();
  scorer.iterator().advance(targetDoc);

-  scorer.getDocInfo().setOriginalDocScore(hit.score);
+  scorer.getDocInfo().setOriginalDocScore(hit.doc, hit.score);
  hit.score = scorer.score();
  if (hitUpto < topN) {
    reranked[hitUpto] = hit;
@@ -274,7 +274,7 @@ public static LTRScoringQuery.FeatureInfo[] extractFeaturesInfo(LTRScoringQuery.
  if (originalDocScore != null) {
    // If results have not been reranked, the score passed in is the original query's
    // score, which some features can use instead of recalculating it
-    r.getDocInfo().setOriginalDocScore(originalDocScore);
+    r.getDocInfo().setOriginalDocScore(docid, originalDocScore);
  }
  r.score();
  return modelWeight.getFeaturesInfo();
solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java

@@ -509,8 +509,8 @@ public DocInfo getDocInfo() {
public ModelScorer(Weight weight, List<Feature.FeatureWeight.FeatureScorer> featureScorers) {
  super(weight);
  docInfo = new DocInfo();
-  for (final Feature.FeatureWeight.FeatureScorer subSocer : featureScorers) {
-    subSocer.setDocInfo(docInfo);
+  for (final Feature.FeatureWeight.FeatureScorer subScorer : featureScorers) {
+    subScorer.setDocInfo(docInfo);
  }
  if (featureScorers.size() <= 1) {
    // future enhancement: allow the use of dense features in other cases
solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/OriginalScoreFeature.java

@@ -26,7 +26,6 @@
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
-import org.apache.solr.ltr.DocInfo;
import org.apache.solr.request.SolrQueryRequest;
/**
 * This feature returns the original score that the document had before performing
@@ -94,8 +93,7 @@ public float score() throws IOException {
  // This is done to improve the speed of feature extraction. Since this
  // was already scored in step 1
  // we shouldn't need to calc original score again.
-  final DocInfo docInfo = getDocInfo();
-  return (docInfo != null && docInfo.hasOriginalDocScore() ? docInfo.getOriginalDocScore() : in.score());
+  return in.score(); // FIXME bad for performance
}

}
solr/contrib/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java

@@ -325,6 +325,12 @@ public void transform(SolrDocument doc, int docid, float score)
  implTransform(doc, docid, score);
}

+@Override
+public void transform(SolrDocument doc, int docid, float score, float originalScore)
+    throws IOException {
+  implTransform(doc, docid, score, originalScore);
+}
+
@Override
public void transform(SolrDocument doc, int docid)
    throws IOException {
@@ -355,6 +361,30 @@ private void implTransform(SolrDocument doc, int docid, Float score)
  }
}

+private void implTransform(SolrDocument doc, int docid, Float score, Float originalScore)
+    throws IOException {
+  LTRScoringQuery rerankingQuery = rerankingQueries[0];
+  LTRScoringQuery.ModelWeight rerankingModelWeight = modelWeights[0];
+  for (int i = 1; i < rerankingQueries.length; i++) {
+    if (((LTRInterleavingScoringQuery)rerankingQueriesFromContext[i]).getPickedInterleavingDocIds().contains(docid)) {
+      rerankingQuery = rerankingQueries[i];
+      rerankingModelWeight = modelWeights[i];
+    }
+  }
+  if (!(rerankingQuery instanceof OriginalRankingLTRScoringQuery) || hasExplicitFeatureStore) {
+    Object featureVector = featureLogger.getFeatureVector(docid, rerankingQuery, searcher);
+    if (featureVector == null) { // FV for this document was not in the cache
+      featureVector = featureLogger.makeFeatureVector(
+          LTRRescorer.extractFeaturesInfo(
+              rerankingModelWeight,
+              docid,
+              originalScore,
+              leafContexts));
+    }
+    doc.addField(name, featureVector);
+  }
+}
+
}

private static class LoggingModel extends LTRScoringModel {
@@ -96,6 +96,12 @@ public void transform(SolrDocument doc, int docid, float score)
  implTransform(doc, docid);
}

+@Override
+public void transform(SolrDocument doc, int docid, float score, float originalScore)
+    throws IOException {
+  implTransform(doc, docid);
+}
+
@Override
public void transform(SolrDocument doc, int docid)
    throws IOException {
176 changes: 169 additions & 7 deletions solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTROnSolrCloud.java
@@ -16,11 +16,17 @@
package org.apache.solr.ltr;

import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import com.google.common.base.Splitter;
import org.apache.commons.io.FileUtils;
import org.apache.solr.client.solrj.SolrQuery;
import org.apache.solr.client.solrj.embedded.JettyConfig;
@@ -29,6 +35,8 @@
import org.apache.solr.client.solrj.response.CollectionAdminResponse;
import org.apache.solr.client.solrj.response.QueryResponse;
import org.apache.solr.cloud.MiniSolrCloudCluster;
+import org.apache.solr.common.SolrDocument;
+import org.apache.solr.common.SolrDocumentList;
import org.apache.solr.common.SolrInputDocument;
import org.apache.solr.common.cloud.ZkStateReader;
import org.apache.solr.ltr.feature.FieldValueFeature;
@@ -50,24 +58,25 @@ public class TestLTROnSolrCloud extends TestRerankBase {
String schema = "schema.xml";

SortedMap<ServletHolder,String> extraServlets = null;

+private final static String MODEL_WEIGHTS = "\"powpularityS\":1.0,\"c3\":1.0,\"original\":0.1," +
+    "\"dvIntFieldFeature\":0.1,\"dvLongFieldFeature\":0.1," +
+    "\"dvFloatFieldFeature\":0.1,\"dvDoubleFieldFeature\":0.1,\"dvStrNumFieldFeature\":0.1,\"dvStrBoolFieldFeature\":0.1";
+
@Override
public void setUp() throws Exception {
super.setUp();
extraServlets = setupTestInit(solrconfig, schema, true);
System.setProperty("enable.update.log", "true");

-int numberOfShards = random().nextInt(4)+1;
+int numberOfShards = random().nextInt(4)+1; // FIXME testSimpleQueryCustomSortByScoreWithSubResultSetAndRowsLessThanExistingDocs only works if both are set to 1
int numberOfReplicas = random().nextInt(2)+1;

int numberOfNodes = numberOfShards * numberOfReplicas;

setupSolrCluster(numberOfShards, numberOfReplicas, numberOfNodes);
-
-
}

@Override
public void tearDown() throws Exception {
restTestHarness.close();
@@ -76,6 +85,161 @@ public void tearDown() throws Exception {
super.tearDown();
}

@Test
public void testSimpleQueryCustomSort() throws Exception {
SolrQuery query = new SolrQuery("*:*");
query.setRequestHandler("/query");
query.setFields("*,[shard]");
query.setParam("rows", "8");
query.setParam("sort", "id asc");
query.add("rq", "{!ltr model=powpularityS-model reRankDocs=8}");

QueryResponse queryResponse =
solrCluster.getSolrClient().query(COLLECTION, query);
assertEquals(8, queryResponse.getResults().getNumFound());
assertEquals("8", queryResponse.getResults().get(0).get("id").toString());
assertEquals("7", queryResponse.getResults().get(1).get("id").toString());
assertEquals("6", queryResponse.getResults().get(2).get("id").toString());
assertEquals("5", queryResponse.getResults().get(3).get("id").toString());
assertEquals("4", queryResponse.getResults().get(4).get("id").toString());
assertEquals("3", queryResponse.getResults().get(5).get("id").toString());
assertEquals("2", queryResponse.getResults().get(6).get("id").toString());
assertEquals("1", queryResponse.getResults().get(7).get("id").toString());
}

@Test
public void testSimpleQueryCustomSortWithSubResultSet() throws Exception {
final int reRankDocs = 2;
SolrQuery query = new SolrQuery("*:*");
query.setRequestHandler("/query");
query.setFields("*,score,[shard],[fv]");
query.setParam("rows", "8");
query.setParam("sort", "id asc");
query.add("rq", "{!ltr model=powpularityS-model reRankDocs="+reRankDocs+"}");

QueryResponse queryResponse = solrCluster.getSolrClient().query(COLLECTION, query);
SolrDocumentList results = queryResponse.getResults();
assertEquals(8, results.getNumFound());

// save order to use it later
List<String> expectedDocIdOrder = new ArrayList<>();

int docCounter = 0;
float lastScore = Float.MAX_VALUE;
double lastId = 0d;
for(SolrDocument d : results){
float score = (float) d.getFieldValue("score");

double id = Double.parseDouble((String) d.getFirstValue("id"));
expectedDocIdOrder.add((String) d.getFirstValue("id"));
if(docCounter < reRankDocs){
final float calculatedScore = calculateLTRScoreForDoc(d);
assertEquals(calculatedScore, score, 0.0);
assertTrue(lastScore > score);
} else if(docCounter > reRankDocs + 1) {
assertTrue(lastId < id);
}
lastScore = score;
lastId = id;

docCounter++;
}

query.setFields("*,[shard],[fv]");

queryResponse = solrCluster.getSolrClient().query(COLLECTION, query);
results = queryResponse.getResults();
assertEquals(8, results.getNumFound());

List<String> docIdOrder = results.stream()
.map(document -> (String) document.getFirstValue("id"))
.collect(Collectors.toList());

// assert that sorting is correct when we do not return the score via fl param
assertEquals(expectedDocIdOrder, docIdOrder);
}

@Test
public void testSimpleQueryCustomSortWithSubResultSetAndRowsLessThanExistingDocs() throws Exception {
final int reRankDocs = 2;
SolrQuery query = new SolrQuery("*:*");
query.setRequestHandler("/query");
query.setFields("*,score,[shard],[fv]"); // score as fl is needed here to be able to use it for assertions
query.setParam("rows", "6");
query.setParam("sort", "id asc");
query.add("rq", "{!ltr model=powpularityS-model reRankDocs="+reRankDocs+"}");
QueryResponse queryResponse = solrCluster.getSolrClient().query(COLLECTION, query);
SolrDocumentList results = queryResponse.getResults();
assertEquals(6, results.size());
int docCounter = 0;
float lastScore = Float.MAX_VALUE;
double lastId = 0d;
for(SolrDocument d : results){
float score = (float) d.getFieldValue("score");
double id = Double.parseDouble((String) d.getFirstValue("id"));
if(docCounter < reRankDocs){
final float calculatedScore = calculateLTRScoreForDoc(d);
assertEquals(calculatedScore, score, 0.0);
assertTrue(lastScore > score);
} else if(docCounter > reRankDocs + 1) {
assertTrue(lastId < id);
}
lastScore = score;
lastId = id;
docCounter++;
}
}

@Test
public void testSimpleQueryCustomSortByScoreWithSubResultSetAndRowsLessThanExistingDocs() throws Exception {
final int reRankDocs = 2;
SolrQuery query = new SolrQuery("{!func}sub(8,field(popularity))");
query.setRequestHandler("/query");
query.setFields("*,score,originalScore,[shard],[fv]"); // score & originalScore as fl are needed here to be able to use it for assertions
query.setParam("rows", "6");
query.setParam("sort", "score asc");
query.add("rq", "{!ltr model=powpularityS-model reRankDocs="+reRankDocs+"}");
QueryResponse queryResponse = solrCluster.getSolrClient().query(COLLECTION, query);
SolrDocumentList results = queryResponse.getResults();
assertEquals(6, results.size());
int docCounter = 0;
float lastScore = Float.MIN_VALUE;

for(SolrDocument d : results){

System.out.println("id " + d.getFieldValue("id") +
" score " + d.getFieldValue("score") +
" original " + d.getFieldValue("originalScore"));

float score = (float) d.getFieldValue("score");
if(docCounter < reRankDocs){
final float calculatedScore = calculateLTRScoreForDoc(d);
assertEquals(calculatedScore, score, 0.0);
} else if(docCounter > reRankDocs + 1) {
assertTrue(lastScore < score);
}
lastScore = score;
docCounter++;
}
}

private float calculateLTRScoreForDoc(SolrDocument d) {
Matcher matcher = Pattern.compile(",?(\\w+)=(-?[0-9]+\\.[0-9]+)").matcher((String) d.getFieldValue("[fv]"));
Map<String, Float> weights = Splitter.on(",")
.splitToList(MODEL_WEIGHTS)
.stream()
.map(fieldWithWeight -> fieldWithWeight.split(":"))
.collect(Collectors.toMap(fieldAndValue -> fieldAndValue[0].replaceAll("\"", ""),
fieldAndValue -> Float.parseFloat(fieldAndValue[1])));

float score = 0.0f;
while(matcher.find()) {
score += Float.parseFloat(matcher.group(2)) * weights.get(matcher.group(1));
}

return score;
}

@Test
// commented 4-Sep-2018 @LuceneTestCase.BadApple(bugUrl="https://issues.apache.org/jira/browse/SOLR-12028") // 2-Aug-2018
// commented out on: 24-Dec-2018 @BadApple(bugUrl="https://issues.apache.org/jira/browse/SOLR-12028") // 14-Oct-2018
@@ -298,9 +462,7 @@ private void loadModelsAndFeatures() throws Exception {
final String featureStore = "test";
final String[] featureNames = new String[]{"powpularityS", "c3", "original", "dvIntFieldFeature",
"dvLongFieldFeature", "dvFloatFieldFeature", "dvDoubleFieldFeature", "dvStrNumFieldFeature", "dvStrBoolFieldFeature"};
-final String jsonModelParams = "{\"weights\":{\"powpularityS\":1.0,\"c3\":1.0,\"original\":0.1," +
-    "\"dvIntFieldFeature\":0.1,\"dvLongFieldFeature\":0.1," +
-    "\"dvFloatFieldFeature\":0.1,\"dvDoubleFieldFeature\":0.1,\"dvStrNumFieldFeature\":0.1,\"dvStrBoolFieldFeature\":0.1}}";
+final String jsonModelParams = "{\"weights\":{" + MODEL_WEIGHTS + "}}";

loadFeature(
featureNames[0],
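For reference, calculateLTRScoreForDoc in the test above assumes the [fv] transformer returns the dense CSV format (featureName=value pairs separated by commas) and recomputes the linear model score as a dot product with MODEL_WEIGHTS. A standalone sketch of that arithmetic on a hypothetical three-feature vector:

import java.util.Map;

public class FvDotProductSketch {
  public static void main(String[] args) {
    // Hypothetical [fv] output and a matching subset of MODEL_WEIGHTS.
    String fv = "powpularityS=64.0,c3=2.0,original=0.1";
    Map<String, Float> weights = Map.of("powpularityS", 1.0f, "c3", 1.0f, "original", 0.1f);

    float score = 0.0f;
    for (String pair : fv.split(",")) {
      String[] nameAndValue = pair.split("=");
      score += Float.parseFloat(nameAndValue[1]) * weights.get(nameAndValue[0]);
    }
    System.out.println(score); // 64.0*1.0 + 2.0*1.0 + 0.1*0.1 = 66.01
  }
}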