Skip to content

Support Metadata filtering with Neo4J #114

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.filter.Filter;
import java.util.AbstractMap;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
Expand Down Expand Up @@ -284,62 +286,119 @@ public void removeAll(Collection<String> ids) {
}
}

@Override
public void removeAll(Filter filter) {
ensureNotNull(filter, "filter");

final AbstractMap.SimpleEntry<String, Map<String, Object>> filterEntry = new Neo4jFilterMapper().map(filter);

try (var session = session()) {
String statement = String.format(
"CALL { MATCH (n:%1$s) WHERE n.%2$s IS NOT NULL AND size(n.%2$s) = toInteger(%3$s) AND %4$s DETACH DELETE n } IN TRANSACTIONS ",
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doesn't seem to apply the filter?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The filter is applied via filterEntry.getKey() in the String.format and filterEntry.getValue() in the params.
For example, if the Filter is IsEqualTo(key=type, comparisonValue=a) , the filterEntry.getKey() is "n.type = $param_1" and the filterEntry.getValue() is Map.of("param_1", "a") .

Therefore the result is:

 session.run( "CALL { MATCH  .... AND n.type = $param_1 DETACH DELETE n } IN TRANSACTIONS ", Map.of("param_1", "a") )

so that we can handle any neo4j data type

this.sanitizedLabel, this.embeddingProperty, this.dimension, filterEntry.getKey());
final Map<String, Object> params = filterEntry.getValue();
session.run(statement, params);
}
}

@Override
public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest request) {

var embeddingValue = Values.value(request.queryEmbedding().vector());

try (var session = session()) {
Map<String, Object> params = new HashMap<>(Map.of(
"indexName",
indexName,
"embeddingValue",
embeddingValue,
"minScore",
request.minScore(),
"maxResults",
request.maxResults()));

String query =
final Filter filter = request.filter();
if (filter == null) {
return getSearchResUsingVectorIndex(request, embeddingValue, session);
}
return getSearchResUsingVectorSimilarity(request, filter, embeddingValue, session);
}
}

/*
Private methods
*/
private EmbeddingSearchResult getSearchResUsingVectorSimilarity(
EmbeddingSearchRequest request, Filter filter, Value embeddingValue, Session session) {
final AbstractMap.SimpleEntry<String, Map<String, Object>> entry = new Neo4jFilterMapper().map(filter);
final String query = String.format(
"""
CYPHER runtime = parallel parallelRuntimeSupport=all
MATCH (n:%1$s)
WHERE n.%2$s IS NOT NULL AND size(n.%2$s) = toInteger(%3$s) AND %4$s
WITH n, vector.similarity.cosine(n.%2$s, %5$s) AS score
WHERE score >= $minScore
WITH n AS node, score
ORDER BY score DESC
LIMIT $maxResults
"""
+ retrievalQuery,
this.sanitizedLabel,
this.embeddingProperty,
this.dimension,
entry.getKey(),
embeddingValue);
final Map<String, Object> params = entry.getValue();
params.put("minScore", request.minScore());
params.put("maxResults", request.maxResults());
return getEmbeddingSearchResult(session, query, params);
}

private EmbeddingSearchResult<TextSegment> getSearchResUsingVectorIndex(
EmbeddingSearchRequest request, Value embeddingValue, Session session) {
Map<String, Object> params = new HashMap<>(Map.of(
"indexName",
indexName,
"embeddingValue",
embeddingValue,
"minScore",
request.minScore(),
"maxResults",
request.maxResults()));

String query =
"""
CALL db.index.vector.queryNodes($indexName, $maxResults, $embeddingValue)
YIELD node, score
WHERE score >= $minScore
"""
+ retrievalQuery;

if (fullTextQuery != null) {

query +=
"""
CALL db.index.vector.queryNodes($indexName, $maxResults, $embeddingValue)
\nUNION
CALL db.index.fulltext.queryNodes($fullTextIndexName, $fullTextQuery, {limit: $maxResults})
YIELD node, score
WHERE score >= $minScore
"""
+ retrievalQuery;
+ fullTextRetrievalQuery;

if (fullTextQuery != null) {
params.putAll(Map.of(
"fullTextIndexName", fullTextIndexName,
"fullTextQuery", fullTextQuery));
}

query +=
"""
\nUNION
CALL db.index.fulltext.queryNodes($fullTextIndexName, $fullTextQuery, {limit: $maxResults})
YIELD node, score
WHERE score >= $minScore
"""
+ fullTextRetrievalQuery;

params.putAll(Map.of(
"fullTextIndexName", fullTextIndexName,
"fullTextQuery", fullTextQuery));
}
final Set<String> columns = getColumnNames(session, query);
final Set<Object> allowedColumn = Set.of(textProperty, embeddingProperty, idProperty, SCORE, METADATA);

final String finalQuery = query;
final Set<String> columns = getColumnNames(session, query);
final Set<Object> allowedColumn = Set.of(textProperty, embeddingProperty, idProperty, SCORE, METADATA);
if (!allowedColumn.containsAll(columns) || columns.size() > allowedColumn.size()) {
throw new RuntimeException(COLUMNS_NOT_ALLOWED_ERR + columns);
}

if (!allowedColumn.containsAll(columns) || columns.size() > allowedColumn.size()) {
throw new RuntimeException(COLUMNS_NOT_ALLOWED_ERR + columns);
}
return getEmbeddingSearchResult(session, query, params);
}

List<EmbeddingMatch<TextSegment>> matches =
session.executeRead(tx -> tx.run(finalQuery, params).list(item -> toEmbeddingMatch(this, item)));
private EmbeddingSearchResult<TextSegment> getEmbeddingSearchResult(
Session session, String query, Map<String, Object> params) {
List<EmbeddingMatch<TextSegment>> matches =
session.executeRead(tx -> tx.run(query, params).list(item -> toEmbeddingMatch(this, item)));

return new EmbeddingSearchResult<>(matches);
}
return new EmbeddingSearchResult<>(matches);
}

private static Set<String> getColumnNames(Session session, String query) {
private Set<String> getColumnNames(Session session, String query) {
// retrieve column names
final List<String> keys = session.run("EXPLAIN " + query).keys();
// when there are multiple variables with the same name, e.g. within a "UNION ALL" Neo4j adds a suffix
Expand All @@ -348,10 +407,6 @@ private static Set<String> getColumnNames(Session session, String query) {
return keys.stream().map(i -> i.replaceFirst("@[0-9]+", "").trim()).collect(Collectors.toSet());
}

/*
Private methods
*/

private void addInternal(String id, Embedding embedding, TextSegment embedded) {
addAll(singletonList(id), singletonList(embedding), embedded == null ? null : singletonList(embedded));
}
Expand Down Expand Up @@ -436,8 +491,9 @@ private void createFullTextIndex() {

try (var session = session()) {

final String query = "CREATE FULLTEXT INDEX %s IF NOT EXISTS FOR (n:%s) ON EACH [n.%s]"
.formatted(this.fullTextIndexName, this.sanitizedLabel, this.sanitizedIdProperty);
final String query = String.format(
"CREATE FULLTEXT INDEX %s IF NOT EXISTS FOR (n:%s) ON EACH [n.%s]",
this.fullTextIndexName, this.sanitizedLabel, this.sanitizedIdProperty);
session.run(query).consume();
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package dev.langchain4j.community.store.embedding.neo4j;

import static org.neo4j.cypherdsl.support.schema_name.SchemaNames.sanitize;

import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThan;
import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThanOrEqualTo;
import dev.langchain4j.store.embedding.filter.comparison.IsIn;
import dev.langchain4j.store.embedding.filter.comparison.IsLessThan;
import dev.langchain4j.store.embedding.filter.comparison.IsLessThanOrEqualTo;
import dev.langchain4j.store.embedding.filter.comparison.IsNotEqualTo;
import dev.langchain4j.store.embedding.filter.comparison.IsNotIn;
import dev.langchain4j.store.embedding.filter.logical.And;
import dev.langchain4j.store.embedding.filter.logical.Not;
import dev.langchain4j.store.embedding.filter.logical.Or;
import java.util.AbstractMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

public class Neo4jFilterMapper {

public static final String UNSUPPORTED_FILTER_TYPE_ERROR = "Unsupported filter type: ";

public static class IncrementalKeyMap {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a list might be better

private final Map<String, Object> map = new ConcurrentHashMap<>();

private final AtomicInteger integer = new AtomicInteger();

public String put(Object value) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like it should use AtomicInteger to keep thread-safe. WDYT?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed to private final AtomicInteger integer = new AtomicInteger(); 👍

String key = "param_" + integer.incrementAndGet();
map.put(key, value);
return key;
}

public Map<String, Object> getMap() {
return map;
}
}

public Neo4jFilterMapper() {}

final IncrementalKeyMap map = new IncrementalKeyMap();

AbstractMap.SimpleEntry<String, Map<String, Object>> map(Filter filter) {
final String stringMapPair = getStringMapping(filter);
return new AbstractMap.SimpleEntry<>(stringMapPair, map.getMap());
}

private String getStringMapping(Filter filter) {
if (filter instanceof IsEqualTo item) {
return getOperation(item.key(), "=", item.comparisonValue());
} else if (filter instanceof IsNotEqualTo item) {
return getOperation(item.key(), "<>", item.comparisonValue());
} else if (filter instanceof IsGreaterThan item) {
return getOperation(item.key(), ">", item.comparisonValue());
} else if (filter instanceof IsGreaterThanOrEqualTo item) {
return getOperation(item.key(), ">=", item.comparisonValue());
} else if (filter instanceof IsLessThan item) {
return getOperation(item.key(), "<", item.comparisonValue());
} else if (filter instanceof IsLessThanOrEqualTo item) {
return getOperation(item.key(), "<=", item.comparisonValue());
} else if (filter instanceof IsIn item) {
return mapIn(item);
} else if (filter instanceof IsNotIn item) {
return mapNotIn(item);
} else if (filter instanceof And item) {
return mapAnd(item);
} else if (filter instanceof Not item) {
return mapNot(item);
} else if (filter instanceof Or item) {
return mapOr(item);
} else {
throw new UnsupportedOperationException(
UNSUPPORTED_FILTER_TYPE_ERROR + filter.getClass().getName());
}
}

private String getOperation(String key, String operator, Object value) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

usually this is an enum of operators which is then also able to format it
and a record(property, Operator, value) and then a list of that

but I think Cypher DSL has this stuff out of the box.

// put ($param_N, <value>) entry map
final String param = map.put(value);

String sanitizedKey = sanitize(key).orElseThrow(() -> {
String invalidSanitizeValue = String.format(
"The key %s, to assign to the operator %s and value %s, cannot be safely quoted",
key, operator, value);
return new RuntimeException(invalidSanitizeValue);
});

return String.format("n.%s %s $%s", sanitizedKey, operator, param);
}

public String mapIn(IsIn filter) {
return getOperation(filter.key(), "IN", filter.comparisonValues());
}

public String mapNotIn(IsNotIn filter) {
final String inOperation = getOperation(filter.key(), "IN", filter.comparisonValues());
return String.format("NOT (%s)", inOperation);
}

private String mapAnd(And filter) {
return String.format("(%s) AND (%s)", getStringMapping(filter.left()), getStringMapping(filter.right()));
}

private String mapOr(Or filter) {
return String.format("(%s) OR (%s)", getStringMapping(filter.left()), getStringMapping(filter.right()));
}

private String mapNot(Not filter) {
return String.format("NOT (%s)", getStringMapping(filter.expression()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_4_O_MINI;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.fail;

import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.DocumentParser;
Expand Down Expand Up @@ -67,9 +68,10 @@ void should_emulate_issue_1306_case() {
Document textDocument = extractor.transform(document);

session.executeWrite(tx -> {
final String s = "CREATE FULLTEXT INDEX elizabeth_text IF NOT EXISTS FOR (e:%s) ON EACH [e.%s]"
.formatted(label, DEFAULT_ID_PROP);
tx.run(s).consume();
final String query = String.format(
"CREATE FULLTEXT INDEX elizabeth_text IF NOT EXISTS FOR (e:%s) ON EACH [e.%s]",
label, DEFAULT_ID_PROP);
tx.run(query).consume();
return null;
});

Expand Down Expand Up @@ -139,4 +141,35 @@ void should_emulate_issue_1306_case() {
assertThat(matchesWithFullText).hasSizeGreaterThanOrEqualTo(1);
matchesWithFullText.forEach(i -> assertThat(i.embeddingId()).contains("Elizabeth"));
}

@Test
void should_throw_error_if_another_index_name_with_different_label_exists() {
String metadataPrefix = "metadata.";
String idxName = "WillFail";

embeddingStore = Neo4jEmbeddingStore.builder()
.withBasicAuth(neo4jContainer.getBoltUrl(), USERNAME, ADMIN_PASSWORD)
.dimension(384)
.indexName(idxName)
.metadataPrefix(metadataPrefix)
.awaitIndexTimeout(20)
.build();

String secondLabel = "Second label";
try {
embeddingStore = Neo4jEmbeddingStore.builder()
.withBasicAuth(neo4jContainer.getBoltUrl(), USERNAME, ADMIN_PASSWORD)
.dimension(384)
.label(secondLabel)
.indexName(idxName)
.metadataPrefix(metadataPrefix)
.build();
fail("Should fail due to idx conflict");
} catch (RuntimeException e) {
String errMsg = String.format(
"It's not possible to create an index for the label `%s` and the property `%s`",
secondLabel, DEFAULT_EMBEDDING_PROP);
assertThat(e.getMessage()).contains(errMsg);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,4 @@ protected EmbeddingStore<TextSegment> embeddingStore() {
protected EmbeddingModel embeddingModel() {
return embeddingModel;
}

@Override
protected boolean supportsRemoveAllByFilter() {
return false;
}
}
Loading