diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GraphLookupOperation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GraphLookupOperation.java index a2fe238627..02b3fb4ea8 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GraphLookupOperation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GraphLookupOperation.java @@ -46,7 +46,7 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation private static final Set> ALLOWED_START_TYPES = new HashSet>( Arrays.> asList(AggregationExpression.class, String.class, Field.class, Document.class)); - private final String from; + private final Object from; private final List startWith; private final Field connectFrom; private final Field connectTo; @@ -55,7 +55,7 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation private final @Nullable Field depthField; private final @Nullable CriteriaDefinition restrictSearchWithMatch; - private GraphLookupOperation(String from, List startWith, Field connectFrom, Field connectTo, Field as, + private GraphLookupOperation(Object from, List startWith, Field connectFrom, Field connectTo, Field as, @Nullable Long maxDepth, @Nullable Field depthField, @Nullable CriteriaDefinition restrictSearchWithMatch) { this.from = from; @@ -82,7 +82,7 @@ public Document toDocument(AggregationOperationContext context) { Document graphLookup = new Document(); - graphLookup.put("from", from); + graphLookup.put("from", getCollectionName(context)); List mappedStartWith = new ArrayList<>(startWith.size()); @@ -99,7 +99,7 @@ public Document toDocument(AggregationOperationContext context) { graphLookup.put("startWith", mappedStartWith.size() == 1 ? mappedStartWith.iterator().next() : mappedStartWith); - graphLookup.put("connectFromField", connectFrom.getTarget()); + graphLookup.put("connectFromField", getForeignFieldName(context)); graphLookup.put("connectToField", connectTo.getTarget()); graphLookup.put("as", as.getName()); @@ -118,6 +118,16 @@ public Document toDocument(AggregationOperationContext context) { return new Document(getOperator(), graphLookup); } + String getCollectionName(AggregationOperationContext context) { + return from instanceof Class type ? context.getCollection(type) : from.toString(); + } + + String getForeignFieldName(AggregationOperationContext context) { + + return from instanceof Class type ? context.getMappedFieldName(type, connectFrom.getTarget()) + : connectFrom.getTarget(); + } + @Override public String getOperator() { return "$graphLookup"; @@ -128,7 +138,7 @@ public ExposedFields getFields() { List fields = new ArrayList<>(2); fields.add(new ExposedField(as, true)); - if(depthField != null) { + if (depthField != null) { fields.add(new ExposedField(depthField, true)); } return ExposedFields.from(fields.toArray(new ExposedField[0])); @@ -146,6 +156,17 @@ public interface FromBuilder { * @return never {@literal null}. */ StartWithBuilder from(String collectionName); + + /** + * Use the given type to determine name of the foreign collection and map + * {@link ConnectFromBuilder#connectFrom(String)} against it to consider eventually present + * {@link org.springframework.data.mongodb.core.mapping.Field} annotations. + * + * @param type must not be {@literal null}. + * @return never {@literal null}. + * @since 4.2 + */ + StartWithBuilder from(Class type); } /** @@ -218,7 +239,7 @@ public interface ConnectToBuilder { static final class GraphLookupOperationFromBuilder implements FromBuilder, StartWithBuilder, ConnectFromBuilder, ConnectToBuilder { - private @Nullable String from; + private @Nullable Object from; private @Nullable List startWith; private @Nullable String connectFrom; @@ -231,6 +252,14 @@ public StartWithBuilder from(String collectionName) { return this; } + @Override + public StartWithBuilder from(Class type) { + + Assert.notNull(type, "Type must not be null"); + this.from = type; + return this; + } + @Override public ConnectFromBuilder startWith(String... fieldReferences) { @@ -321,7 +350,7 @@ public GraphLookupOperationBuilder connectTo(String fieldName) { */ public static final class GraphLookupOperationBuilder { - private final String from; + private final Object from; private final List startWith; private final Field connectFrom; private final Field connectTo; @@ -329,7 +358,7 @@ public static final class GraphLookupOperationBuilder { private @Nullable Field depthField; private @Nullable CriteriaDefinition restrictSearchWithMatch; - protected GraphLookupOperationBuilder(String from, List startWith, String connectFrom, + protected GraphLookupOperationBuilder(Object from, List startWith, String connectFrom, String connectTo) { this.from = from; diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/GraphLookupOperationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/GraphLookupOperationUnitTests.java index 5280b603f6..6d85b31ffb 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/GraphLookupOperationUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/GraphLookupOperationUnitTests.java @@ -22,6 +22,7 @@ import org.bson.Document; import org.junit.jupiter.api.Test; import org.springframework.data.mongodb.core.Person; +import org.springframework.data.mongodb.core.mapping.Field; import org.springframework.data.mongodb.core.query.Criteria; /** @@ -34,7 +35,7 @@ public class GraphLookupOperationUnitTests { @Test // DATAMONGO-1551 public void rejectsNullFromCollection() { - assertThatIllegalArgumentException().isThrownBy(() -> GraphLookupOperation.builder().from(null)); + assertThatIllegalArgumentException().isThrownBy(() -> GraphLookupOperation.builder().from((String) null)); } @Test // DATAMONGO-1551 @@ -158,4 +159,59 @@ public void depthFieldShouldUseTargetFieldInsteadOfAlias() { assertThat(document).containsEntry("$graphLookup.depthField", "foo.bar"); } + + @Test // GH-4379 + void unmappedLookupWithFromExtractedFromType() { + + GraphLookupOperation graphLookupOperation = GraphLookupOperation.builder() // + .from(Employee.class) // + .startWith(LiteralOperators.Literal.asLiteral("hello")) // + .connectFrom("manager") // + .connectTo("name") // + .as("reportingHierarchy"); + + assertThat(graphLookupOperation.toDocument(Aggregation.DEFAULT_CONTEXT)).isEqualTo(""" + { $graphLookup: + { + from: "employee", + startWith : { $literal : "hello" }, + connectFromField: "manager", + connectToField: "name", + as: "reportingHierarchy" + } + }} + """); + } + + @Test // GH-4379 + void mappedLookupWithFromExtractedFromType() { + + GraphLookupOperation graphLookupOperation = GraphLookupOperation.builder() // + .from(Employee.class) // + .startWith(LiteralOperators.Literal.asLiteral("hello")) // + .connectFrom("manager") // + .connectTo("name") // + .as("reportingHierarchy"); + + assertThat(graphLookupOperation.toDocument(AggregationTestUtils.strict(Employee.class).ctx())).isEqualTo(""" + { $graphLookup: + { + from: "employees", + startWith : { $literal : "hello" }, + connectFromField: "reportsTo", + connectToField: "name", + as: "reportingHierarchy" + } + }} + """); + } + + @org.springframework.data.mongodb.core.mapping.Document("employees") + static class Employee { + + String id; + + @Field("reportsTo") + String manager; + } }