diff --git a/embedding-stores/langchain4j-community-neo4j/src/main/java/dev/langchain4j/community/store/embedding/neo4j/Neo4jEmbeddingStore.java b/embedding-stores/langchain4j-community-neo4j/src/main/java/dev/langchain4j/community/store/embedding/neo4j/Neo4jEmbeddingStore.java index 04e49a72..8ca519b2 100644 --- a/embedding-stores/langchain4j-community-neo4j/src/main/java/dev/langchain4j/community/store/embedding/neo4j/Neo4jEmbeddingStore.java +++ b/embedding-stores/langchain4j-community-neo4j/src/main/java/dev/langchain4j/community/store/embedding/neo4j/Neo4jEmbeddingStore.java @@ -30,6 +30,8 @@ import dev.langchain4j.store.embedding.EmbeddingSearchRequest; import dev.langchain4j.store.embedding.EmbeddingSearchResult; import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.filter.Filter; +import java.util.AbstractMap; import java.util.Arrays; import java.util.Collection; import java.util.HashMap; @@ -284,62 +286,119 @@ public void removeAll(Collection ids) { } } + @Override + public void removeAll(Filter filter) { + ensureNotNull(filter, "filter"); + + final AbstractMap.SimpleEntry> filterEntry = new Neo4jFilterMapper().map(filter); + + try (var session = session()) { + String statement = String.format( + "CALL { MATCH (n:%1$s) WHERE n.%2$s IS NOT NULL AND size(n.%2$s) = toInteger(%3$s) AND %4$s DETACH DELETE n } IN TRANSACTIONS ", + this.sanitizedLabel, this.embeddingProperty, this.dimension, filterEntry.getKey()); + final Map params = filterEntry.getValue(); + session.run(statement, params); + } + } + @Override public EmbeddingSearchResult search(EmbeddingSearchRequest request) { var embeddingValue = Values.value(request.queryEmbedding().vector()); try (var session = session()) { - Map params = new HashMap<>(Map.of( - "indexName", - indexName, - "embeddingValue", - embeddingValue, - "minScore", - request.minScore(), - "maxResults", - request.maxResults())); - - String query = + final Filter filter = request.filter(); + if (filter == null) { + return getSearchResUsingVectorIndex(request, embeddingValue, session); + } + return getSearchResUsingVectorSimilarity(request, filter, embeddingValue, session); + } + } + + /* + Private methods + */ + private EmbeddingSearchResult getSearchResUsingVectorSimilarity( + EmbeddingSearchRequest request, Filter filter, Value embeddingValue, Session session) { + final AbstractMap.SimpleEntry> entry = new Neo4jFilterMapper().map(filter); + final String query = String.format( + """ + CYPHER runtime = parallel parallelRuntimeSupport=all + MATCH (n:%1$s) + WHERE n.%2$s IS NOT NULL AND size(n.%2$s) = toInteger(%3$s) AND %4$s + WITH n, vector.similarity.cosine(n.%2$s, %5$s) AS score + WHERE score >= $minScore + WITH n AS node, score + ORDER BY score DESC + LIMIT $maxResults + """ + + retrievalQuery, + this.sanitizedLabel, + this.embeddingProperty, + this.dimension, + entry.getKey(), + embeddingValue); + final Map params = entry.getValue(); + params.put("minScore", request.minScore()); + params.put("maxResults", request.maxResults()); + return getEmbeddingSearchResult(session, query, params); + } + + private EmbeddingSearchResult getSearchResUsingVectorIndex( + EmbeddingSearchRequest request, Value embeddingValue, Session session) { + Map params = new HashMap<>(Map.of( + "indexName", + indexName, + "embeddingValue", + embeddingValue, + "minScore", + request.minScore(), + "maxResults", + request.maxResults())); + + String query = + """ + CALL db.index.vector.queryNodes($indexName, $maxResults, $embeddingValue) + YIELD node, score + WHERE score >= $minScore + """ + + retrievalQuery; + + if (fullTextQuery != null) { + + query += """ - CALL db.index.vector.queryNodes($indexName, $maxResults, $embeddingValue) + \nUNION + CALL db.index.fulltext.queryNodes($fullTextIndexName, $fullTextQuery, {limit: $maxResults}) YIELD node, score WHERE score >= $minScore """ - + retrievalQuery; + + fullTextRetrievalQuery; - if (fullTextQuery != null) { + params.putAll(Map.of( + "fullTextIndexName", fullTextIndexName, + "fullTextQuery", fullTextQuery)); + } - query += - """ - \nUNION - CALL db.index.fulltext.queryNodes($fullTextIndexName, $fullTextQuery, {limit: $maxResults}) - YIELD node, score - WHERE score >= $minScore - """ - + fullTextRetrievalQuery; - - params.putAll(Map.of( - "fullTextIndexName", fullTextIndexName, - "fullTextQuery", fullTextQuery)); - } + final Set columns = getColumnNames(session, query); + final Set allowedColumn = Set.of(textProperty, embeddingProperty, idProperty, SCORE, METADATA); - final String finalQuery = query; - final Set columns = getColumnNames(session, query); - final Set allowedColumn = Set.of(textProperty, embeddingProperty, idProperty, SCORE, METADATA); + if (!allowedColumn.containsAll(columns) || columns.size() > allowedColumn.size()) { + throw new RuntimeException(COLUMNS_NOT_ALLOWED_ERR + columns); + } - if (!allowedColumn.containsAll(columns) || columns.size() > allowedColumn.size()) { - throw new RuntimeException(COLUMNS_NOT_ALLOWED_ERR + columns); - } + return getEmbeddingSearchResult(session, query, params); + } - List> matches = - session.executeRead(tx -> tx.run(finalQuery, params).list(item -> toEmbeddingMatch(this, item))); + private EmbeddingSearchResult getEmbeddingSearchResult( + Session session, String query, Map params) { + List> matches = + session.executeRead(tx -> tx.run(query, params).list(item -> toEmbeddingMatch(this, item))); - return new EmbeddingSearchResult<>(matches); - } + return new EmbeddingSearchResult<>(matches); } - private static Set getColumnNames(Session session, String query) { + private Set getColumnNames(Session session, String query) { // retrieve column names final List keys = session.run("EXPLAIN " + query).keys(); // when there are multiple variables with the same name, e.g. within a "UNION ALL" Neo4j adds a suffix @@ -348,10 +407,6 @@ private static Set getColumnNames(Session session, String query) { return keys.stream().map(i -> i.replaceFirst("@[0-9]+", "").trim()).collect(Collectors.toSet()); } - /* - Private methods - */ - private void addInternal(String id, Embedding embedding, TextSegment embedded) { addAll(singletonList(id), singletonList(embedding), embedded == null ? null : singletonList(embedded)); } @@ -436,8 +491,9 @@ private void createFullTextIndex() { try (var session = session()) { - final String query = "CREATE FULLTEXT INDEX %s IF NOT EXISTS FOR (n:%s) ON EACH [n.%s]" - .formatted(this.fullTextIndexName, this.sanitizedLabel, this.sanitizedIdProperty); + final String query = String.format( + "CREATE FULLTEXT INDEX %s IF NOT EXISTS FOR (n:%s) ON EACH [n.%s]", + this.fullTextIndexName, this.sanitizedLabel, this.sanitizedIdProperty); session.run(query).consume(); } } diff --git a/embedding-stores/langchain4j-community-neo4j/src/main/java/dev/langchain4j/community/store/embedding/neo4j/Neo4jFilterMapper.java b/embedding-stores/langchain4j-community-neo4j/src/main/java/dev/langchain4j/community/store/embedding/neo4j/Neo4jFilterMapper.java new file mode 100644 index 00000000..dcb6fe7b --- /dev/null +++ b/embedding-stores/langchain4j-community-neo4j/src/main/java/dev/langchain4j/community/store/embedding/neo4j/Neo4jFilterMapper.java @@ -0,0 +1,114 @@ +package dev.langchain4j.community.store.embedding.neo4j; + +import static org.neo4j.cypherdsl.support.schema_name.SchemaNames.sanitize; + +import dev.langchain4j.store.embedding.filter.Filter; +import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo; +import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThan; +import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThanOrEqualTo; +import dev.langchain4j.store.embedding.filter.comparison.IsIn; +import dev.langchain4j.store.embedding.filter.comparison.IsLessThan; +import dev.langchain4j.store.embedding.filter.comparison.IsLessThanOrEqualTo; +import dev.langchain4j.store.embedding.filter.comparison.IsNotEqualTo; +import dev.langchain4j.store.embedding.filter.comparison.IsNotIn; +import dev.langchain4j.store.embedding.filter.logical.And; +import dev.langchain4j.store.embedding.filter.logical.Not; +import dev.langchain4j.store.embedding.filter.logical.Or; +import java.util.AbstractMap; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; + +public class Neo4jFilterMapper { + + public static final String UNSUPPORTED_FILTER_TYPE_ERROR = "Unsupported filter type: "; + + public static class IncrementalKeyMap { + private final Map map = new ConcurrentHashMap<>(); + + private final AtomicInteger integer = new AtomicInteger(); + + public String put(Object value) { + String key = "param_" + integer.incrementAndGet(); + map.put(key, value); + return key; + } + + public Map getMap() { + return map; + } + } + + public Neo4jFilterMapper() {} + + final IncrementalKeyMap map = new IncrementalKeyMap(); + + AbstractMap.SimpleEntry> map(Filter filter) { + final String stringMapPair = getStringMapping(filter); + return new AbstractMap.SimpleEntry<>(stringMapPair, map.getMap()); + } + + private String getStringMapping(Filter filter) { + if (filter instanceof IsEqualTo item) { + return getOperation(item.key(), "=", item.comparisonValue()); + } else if (filter instanceof IsNotEqualTo item) { + return getOperation(item.key(), "<>", item.comparisonValue()); + } else if (filter instanceof IsGreaterThan item) { + return getOperation(item.key(), ">", item.comparisonValue()); + } else if (filter instanceof IsGreaterThanOrEqualTo item) { + return getOperation(item.key(), ">=", item.comparisonValue()); + } else if (filter instanceof IsLessThan item) { + return getOperation(item.key(), "<", item.comparisonValue()); + } else if (filter instanceof IsLessThanOrEqualTo item) { + return getOperation(item.key(), "<=", item.comparisonValue()); + } else if (filter instanceof IsIn item) { + return mapIn(item); + } else if (filter instanceof IsNotIn item) { + return mapNotIn(item); + } else if (filter instanceof And item) { + return mapAnd(item); + } else if (filter instanceof Not item) { + return mapNot(item); + } else if (filter instanceof Or item) { + return mapOr(item); + } else { + throw new UnsupportedOperationException( + UNSUPPORTED_FILTER_TYPE_ERROR + filter.getClass().getName()); + } + } + + private String getOperation(String key, String operator, Object value) { + // put ($param_N, ) entry map + final String param = map.put(value); + + String sanitizedKey = sanitize(key).orElseThrow(() -> { + String invalidSanitizeValue = String.format( + "The key %s, to assign to the operator %s and value %s, cannot be safely quoted", + key, operator, value); + return new RuntimeException(invalidSanitizeValue); + }); + + return String.format("n.%s %s $%s", sanitizedKey, operator, param); + } + + public String mapIn(IsIn filter) { + return getOperation(filter.key(), "IN", filter.comparisonValues()); + } + + public String mapNotIn(IsNotIn filter) { + final String inOperation = getOperation(filter.key(), "IN", filter.comparisonValues()); + return String.format("NOT (%s)", inOperation); + } + + private String mapAnd(And filter) { + return String.format("(%s) AND (%s)", getStringMapping(filter.left()), getStringMapping(filter.right())); + } + + private String mapOr(Or filter) { + return String.format("(%s) OR (%s)", getStringMapping(filter.left()), getStringMapping(filter.right())); + } + + private String mapNot(Not filter) { + return String.format("NOT (%s)", getStringMapping(filter.expression())); + } +} diff --git a/embedding-stores/langchain4j-community-neo4j/src/test/java/dev/langchain4j/community/store/embedding/neo4j/Neo4jEmbeddingStoreIT.java b/embedding-stores/langchain4j-community-neo4j/src/test/java/dev/langchain4j/community/store/embedding/neo4j/Neo4jEmbeddingStoreIT.java index 732c6f13..0a5e8701 100644 --- a/embedding-stores/langchain4j-community-neo4j/src/test/java/dev/langchain4j/community/store/embedding/neo4j/Neo4jEmbeddingStoreIT.java +++ b/embedding-stores/langchain4j-community-neo4j/src/test/java/dev/langchain4j/community/store/embedding/neo4j/Neo4jEmbeddingStoreIT.java @@ -6,6 +6,7 @@ import static dev.langchain4j.internal.RetryUtils.withRetry; import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_4_O_MINI; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.Assert.fail; import dev.langchain4j.data.document.Document; import dev.langchain4j.data.document.DocumentParser; @@ -67,9 +68,10 @@ void should_emulate_issue_1306_case() { Document textDocument = extractor.transform(document); session.executeWrite(tx -> { - final String s = "CREATE FULLTEXT INDEX elizabeth_text IF NOT EXISTS FOR (e:%s) ON EACH [e.%s]" - .formatted(label, DEFAULT_ID_PROP); - tx.run(s).consume(); + final String query = String.format( + "CREATE FULLTEXT INDEX elizabeth_text IF NOT EXISTS FOR (e:%s) ON EACH [e.%s]", + label, DEFAULT_ID_PROP); + tx.run(query).consume(); return null; }); @@ -139,4 +141,35 @@ void should_emulate_issue_1306_case() { assertThat(matchesWithFullText).hasSizeGreaterThanOrEqualTo(1); matchesWithFullText.forEach(i -> assertThat(i.embeddingId()).contains("Elizabeth")); } + + @Test + void should_throw_error_if_another_index_name_with_different_label_exists() { + String metadataPrefix = "metadata."; + String idxName = "WillFail"; + + embeddingStore = Neo4jEmbeddingStore.builder() + .withBasicAuth(neo4jContainer.getBoltUrl(), USERNAME, ADMIN_PASSWORD) + .dimension(384) + .indexName(idxName) + .metadataPrefix(metadataPrefix) + .awaitIndexTimeout(20) + .build(); + + String secondLabel = "Second label"; + try { + embeddingStore = Neo4jEmbeddingStore.builder() + .withBasicAuth(neo4jContainer.getBoltUrl(), USERNAME, ADMIN_PASSWORD) + .dimension(384) + .label(secondLabel) + .indexName(idxName) + .metadataPrefix(metadataPrefix) + .build(); + fail("Should fail due to idx conflict"); + } catch (RuntimeException e) { + String errMsg = String.format( + "It's not possible to create an index for the label `%s` and the property `%s`", + secondLabel, DEFAULT_EMBEDDING_PROP); + assertThat(e.getMessage()).contains(errMsg); + } + } } diff --git a/embedding-stores/langchain4j-community-neo4j/src/test/java/dev/langchain4j/community/store/embedding/neo4j/Neo4jEmbeddingStoreRemovalIT.java b/embedding-stores/langchain4j-community-neo4j/src/test/java/dev/langchain4j/community/store/embedding/neo4j/Neo4jEmbeddingStoreRemovalIT.java index 9cff46a6..25c00be9 100644 --- a/embedding-stores/langchain4j-community-neo4j/src/test/java/dev/langchain4j/community/store/embedding/neo4j/Neo4jEmbeddingStoreRemovalIT.java +++ b/embedding-stores/langchain4j-community-neo4j/src/test/java/dev/langchain4j/community/store/embedding/neo4j/Neo4jEmbeddingStoreRemovalIT.java @@ -45,9 +45,4 @@ protected EmbeddingStore embeddingStore() { protected EmbeddingModel embeddingModel() { return embeddingModel; } - - @Override - protected boolean supportsRemoveAllByFilter() { - return false; - } } diff --git a/embedding-stores/langchain4j-community-neo4j/src/test/java/dev/langchain4j/community/store/embedding/neo4j/Neo4jEmbeddingStoreTest.java b/embedding-stores/langchain4j-community-neo4j/src/test/java/dev/langchain4j/community/store/embedding/neo4j/Neo4jEmbeddingStoreTest.java index b0d4bfe7..bc50cdb9 100644 --- a/embedding-stores/langchain4j-community-neo4j/src/test/java/dev/langchain4j/community/store/embedding/neo4j/Neo4jEmbeddingStoreTest.java +++ b/embedding-stores/langchain4j-community-neo4j/src/test/java/dev/langchain4j/community/store/embedding/neo4j/Neo4jEmbeddingStoreTest.java @@ -15,6 +15,13 @@ import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingSearchRequest; +import dev.langchain4j.store.embedding.EmbeddingSearchResult; +import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo; +import dev.langchain4j.store.embedding.filter.comparison.IsIn; +import dev.langchain4j.store.embedding.filter.comparison.IsNotEqualTo; +import dev.langchain4j.store.embedding.filter.logical.And; +import dev.langchain4j.store.embedding.filter.logical.Not; +import dev.langchain4j.store.embedding.filter.logical.Or; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -385,8 +392,9 @@ void should_add_embedding_and_fulltext_with_id() { final Embedding queryEmbedding = embeddingModel.embed(fullTextSearch).content(); session.executeWrite(tx -> { - final String query = "CREATE FULLTEXT INDEX %s IF NOT EXISTS FOR (e:%s) ON EACH [e.%s]" - .formatted(fullTextIndexName, label, DEFAULT_ID_PROP); + final String query = String.format( + "CREATE FULLTEXT INDEX %s IF NOT EXISTS FOR (e:%s) ON EACH [e.%s]", + fullTextIndexName, label, DEFAULT_ID_PROP); tx.run(query).consume(); return null; }); @@ -420,8 +428,55 @@ void should_add_embedding_and_fulltext_with_id() { }); } + @Test + void should_add_embedding_with_id_and_retrieve_with_and_without_prefilter() { + + final List segments = IntStream.range(0, 10) + .boxed() + .map(i -> { + if (i == 0) { + final Map metas = + Map.of("key1", "value1", "key2", 10, "key3", "3", "key4", "value4"); + final Metadata metadata = new Metadata(metas); + return TextSegment.from(randomUUID(), metadata); + } + return TextSegment.from(randomUUID()); + }) + .toList(); + + final List embeddings = embeddingModel.embedAll(segments).content(); + + embeddingStore.addAll(embeddings, segments); + + final And filter = new And( + new And(new IsEqualTo("key1", "value1"), new IsEqualTo("key2", "10")), + new Not(new Or(new IsIn("key3", asList("1", "2")), new IsNotEqualTo("key4", "value4")))); + + TextSegment segmentToSearch = TextSegment.from(randomUUID()); + Embedding embeddingToSearch = + embeddingModel.embed(segmentToSearch.text()).content(); + final EmbeddingSearchRequest requestWithFilter = EmbeddingSearchRequest.builder() + .maxResults(5) + .minScore(0.0) + .filter(filter) + .queryEmbedding(embeddingToSearch) + .build(); + final EmbeddingSearchResult searchWithFilter = embeddingStore.search(requestWithFilter); + final List> matchesWithFilter = searchWithFilter.matches(); + assertThat(matchesWithFilter).hasSize(1); + + final EmbeddingSearchRequest requestWithoutFilter = EmbeddingSearchRequest.builder() + .maxResults(5) + .minScore(0.0) + .queryEmbedding(embeddingToSearch) + .build(); + final EmbeddingSearchResult searchWithoutFilter = embeddingStore.search(requestWithoutFilter); + final List> matchesWithoutFilter = searchWithoutFilter.matches(); + assertThat(matchesWithoutFilter).hasSize(5); + } + private List>> getListRowsBatched(int numElements) { - return getListRowsBatched(numElements, (Neo4jEmbeddingStore) embeddingStore); + return getListRowsBatched(numElements, embeddingStore); } private List>> getListRowsBatched(int numElements, Neo4jEmbeddingStore embeddingStore) { diff --git a/embedding-stores/langchain4j-community-neo4j/src/test/java/dev/langchain4j/community/store/embedding/neo4j/Neo4jFilterMapperTest.java b/embedding-stores/langchain4j-community-neo4j/src/test/java/dev/langchain4j/community/store/embedding/neo4j/Neo4jFilterMapperTest.java new file mode 100644 index 00000000..003d142e --- /dev/null +++ b/embedding-stores/langchain4j-community-neo4j/src/test/java/dev/langchain4j/community/store/embedding/neo4j/Neo4jFilterMapperTest.java @@ -0,0 +1,158 @@ +package dev.langchain4j.community.store.embedding.neo4j; + +import static dev.langchain4j.community.store.embedding.neo4j.Neo4jFilterMapper.UNSUPPORTED_FILTER_TYPE_ERROR; +import static java.util.AbstractMap.SimpleEntry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.fail; + +import dev.langchain4j.store.embedding.filter.Filter; +import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo; +import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThan; +import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThanOrEqualTo; +import dev.langchain4j.store.embedding.filter.comparison.IsIn; +import dev.langchain4j.store.embedding.filter.comparison.IsLessThan; +import dev.langchain4j.store.embedding.filter.comparison.IsLessThanOrEqualTo; +import dev.langchain4j.store.embedding.filter.comparison.IsNotEqualTo; +import dev.langchain4j.store.embedding.filter.comparison.IsNotIn; +import dev.langchain4j.store.embedding.filter.logical.And; +import dev.langchain4j.store.embedding.filter.logical.Not; +import dev.langchain4j.store.embedding.filter.logical.Or; +import java.util.Map; +import java.util.Set; +import org.junit.jupiter.api.Test; + +class Neo4jFilterMapperTest { + + @Test + void should_map_equal() { + IsEqualTo filter = new IsEqualTo("key", "value"); + final SimpleEntry> result = new Neo4jFilterMapper().map(filter); + assertThat(result.getKey()).isEqualTo("n.key = $param_1"); + assertThat(result.getValue()).isEqualTo(Map.of("param_1", "value")); + } + + @Test + void should_map_not_equal() { + IsNotEqualTo filter = new IsNotEqualTo("key", "value"); + final SimpleEntry> result = new Neo4jFilterMapper().map(filter); + assertThat(result.getKey()).isEqualTo("n.key <> $param_1"); + assertThat(result.getValue()).isEqualTo(Map.of("param_1", "value")); + } + + @Test + void should_map_is_greater_than() { + IsGreaterThan filter = new IsGreaterThan("key", 10); + final SimpleEntry> result = new Neo4jFilterMapper().map(filter); + assertThat(result.getKey()).isEqualTo("n.key > $param_1"); + assertThat(result.getValue()).isEqualTo(Map.of("param_1", 10)); + } + + @Test + void should_map_is_greater_than_or_equal_to() { + IsGreaterThanOrEqualTo filter = new IsGreaterThanOrEqualTo("key", 10); + final SimpleEntry> result = new Neo4jFilterMapper().map(filter); + assertThat(result.getKey()).isEqualTo("n.key >= $param_1"); + assertThat(result.getValue()).isEqualTo(Map.of("param_1", 10)); + } + + @Test + void should_map_is_less_than() { + IsLessThan filter = new IsLessThan("key", 10); + final SimpleEntry> result = new Neo4jFilterMapper().map(filter); + assertThat(result.getKey()).isEqualTo("n.key < $param_1"); + assertThat(result.getValue()).isEqualTo(Map.of("param_1", 10)); + } + + @Test + void should_map_is_less_than_or_equal_to() { + IsLessThanOrEqualTo filter = new IsLessThanOrEqualTo("key", 10); + final SimpleEntry> result = new Neo4jFilterMapper().map(filter); + assertThat(result.getKey()).isEqualTo("n.key <= $param_1"); + assertThat(result.getValue()).isEqualTo(Map.of("param_1", 10)); + } + + @Test + void should_map_is_in() { + final Set value = Set.of(1, 2, 3); + IsIn filter = new IsIn("key", value); + final SimpleEntry> result = new Neo4jFilterMapper().map(filter); + assertThat(result.getKey()).isEqualTo("n.key IN $param_1"); + assertThat(result.getValue()).isEqualTo(Map.of("param_1", value)); + } + + @Test + void should_map_is_not_in() { + final Set value = Set.of(1, 2, 3); + IsNotIn filter = new IsNotIn("key", value); + final SimpleEntry> result = new Neo4jFilterMapper().map(filter); + assertThat(result.getKey()).isEqualTo("NOT (n.key IN $param_1)"); + assertThat(result.getValue()).isEqualTo(Map.of("param_1", value)); + } + + @Test + void should_map_and() { + And filter = new And(new IsEqualTo("key1", "value1"), new IsEqualTo("key2", "value2")); + final SimpleEntry> result = new Neo4jFilterMapper().map(filter); + assertThat(result.getKey()).isEqualTo("(n.key1 = $param_1) AND (n.key2 = $param_2)"); + assertThat(result.getValue()).isEqualTo(Map.of("param_1", "value1", "param_2", "value2")); + } + + @Test + void should_map_or() { + Or filter = new Or(new IsEqualTo("key1", "value1"), new IsEqualTo("key2", "value2")); + final SimpleEntry> result = new Neo4jFilterMapper().map(filter); + assertThat(result.getKey()).isEqualTo("(n.key1 = $param_1) OR (n.key2 = $param_2)"); + assertThat(result.getValue()).isEqualTo(Map.of("param_1", "value1", "param_2", "value2")); + } + + @Test + void should_map_or_not_and() { + final Set valueKey3 = Set.of("1", "2"); + Or filter = new Or( + new And(new IsEqualTo("key1", "value1"), new IsGreaterThan("key2", "value2")), + new Not(new And(new IsIn("key3", valueKey3), new IsLessThan("key4", "value4")))); + final SimpleEntry> result = new Neo4jFilterMapper().map(filter); + assertThat(result.getKey()) + .isEqualTo( + "((n.key1 = $param_1) AND (n.key2 > $param_2)) OR (NOT ((n.key3 IN $param_3) AND (n.key4 < $param_4)))"); + assertThat(result.getValue()) + .isEqualTo(Map.of("param_1", "value1", "param_2", "value2", "param_3", valueKey3, "param_4", "value4")); + } + + @Test + void should_correctly_sanitize_key() { + IsEqualTo filter = new IsEqualTo("k\\ ` ey", "value"); + final SimpleEntry> result = new Neo4jFilterMapper().map(filter); + assertThat(result.getKey()).isEqualTo("n.`k\\ `` ey` = $param_1"); + assertThat(result.getValue()).isEqualTo(Map.of("param_1", "value")); + } + + @Test + void should_throws_unsupported_filter_error() { + MockFilter filter = new MockFilter(); + try { + new Neo4jFilterMapper().map(filter); + fail("Should fail due to unsupported filter"); + } catch (UnsupportedOperationException e) { + assertThat(e.getMessage()).contains(UNSUPPORTED_FILTER_TYPE_ERROR); + } + } + + private static class MockFilter implements Filter { + + @Override + public boolean test(final Object object) { + return false; + } + + @Override + public Filter and(final Filter filter) { + return Filter.super.and(filter); + } + + @Override + public Filter or(final Filter filter) { + return Filter.super.or(filter); + } + } +}