Skip to content

Commit 43eff19

Browse files
vga91 and Martin7-1
authored Mar 25, 2025
Migrate neo4j retriever (#109)
* Migrating neo4j-content-retriever
* changed module name
* changes reviews
* added Neo4jContentRetriever as a deprecated class
* split mock and real chatModel tests
* format

---------

Co-authored-by: Martin7-1 <yi.zheng.se@gmail.com>
1 parent d1b6ad2 commit 43eff19

File tree

10 files changed

+612
-0
lines changed

10 files changed

+612
-0
lines changed
 
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven module descriptor for the Neo4j text-to-Cypher content retriever. -->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <!-- Inherits from the community aggregator two directories up. -->
    <parent>
        <groupId>dev.langchain4j</groupId>
        <artifactId>langchain4j-community</artifactId>
        <version>1.0.0-beta3-SNAPSHOT</version>
        <relativePath>../../pom.xml</relativePath>
    </parent>

    <artifactId>langchain4j-community-neo4j-retriever</artifactId>
    <packaging>jar</packaging>
    <name>LangChain4j :: Community :: Content Retriever :: Neo4j</name>

    <dependencies>

        <!-- compile dependencies -->
        <dependency>
            <groupId>dev.langchain4j</groupId>
            <artifactId>langchain4j-core</artifactId>
            <version>${project.version}</version>
        </dependency>

        <!-- No version here: presumably managed by the parent POM (confirm). -->
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-api</artifactId>
        </dependency>

        <!-- NOTE(review): the driver version is pinned inline; consider a shared property if other modules use the driver. -->
        <dependency>
            <groupId>org.neo4j.driver</groupId>
            <artifactId>neo4j-java-driver</artifactId>
            <version>5.26.0</version>
        </dependency>

        <!-- test dependencies -->
        <!-- NOTE(review): the tests in this module import org.mockito; no Mockito artifact is declared here, so confirm it is supplied by the parent POM. -->
        <dependency>
            <groupId>org.testcontainers</groupId>
            <artifactId>junit-jupiter</artifactId>
            <scope>test</scope>
        </dependency>

        <dependency>
            <groupId>org.testcontainers</groupId>
            <artifactId>neo4j</artifactId>
            <scope>test</scope>
        </dependency>

        <dependency>
            <groupId>org.junit.jupiter</groupId>
            <artifactId>junit-jupiter</artifactId>
            <scope>test</scope>
        </dependency>

        <dependency>
            <groupId>dev.langchain4j</groupId>
            <artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
            <scope>test</scope>
        </dependency>

        <dependency>
            <groupId>org.assertj</groupId>
            <artifactId>assertj-core</artifactId>
            <scope>test</scope>
        </dependency>

        <!-- Used by the integration test that talks to a real OpenAI chat model. -->
        <dependency>
            <groupId>dev.langchain4j</groupId>
            <artifactId>langchain4j-open-ai</artifactId>
            <version>${project.version}</version>
            <scope>test</scope>
        </dependency>

    </dependencies>

</project>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package dev.langchain4j.rag.content.retriever.neo4j;
2+
3+
import dev.langchain4j.model.chat.ChatLanguageModel;
4+
import dev.langchain4j.model.input.PromptTemplate;
5+
6+
/**
7+
* @deprecated use {@link Neo4jText2CypherRetriever} instead
8+
*/
9+
@Deprecated(forRemoval = true)
10+
public class Neo4jContentRetriever extends Neo4jText2CypherRetriever {
11+
12+
public Neo4jContentRetriever(Neo4jGraph graph, ChatLanguageModel chatLanguageModel, PromptTemplate promptTemplate) {
13+
super(graph, chatLanguageModel, promptTemplate);
14+
}
15+
16+
public static Builder builder() {
17+
return new Builder();
18+
}
19+
20+
public static class Builder extends Neo4jText2CypherRetriever.Builder<Builder> {
21+
22+
@Override
23+
public Neo4jContentRetriever build() {
24+
return new Neo4jContentRetriever(graph, chatLanguageModel, promptTemplate);
25+
}
26+
}
27+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package dev.langchain4j.rag.content.retriever.neo4j;
2+
3+
/**
 * Unchecked exception used by this module to report a Neo4j failure together with its cause.
 */
public class Neo4jException extends RuntimeException {

    /**
     * @param message human-readable description of what failed
     * @param cause   the underlying error, kept as the exception cause
     */
    public Neo4jException(String message, Throwable cause) {
        super(message, cause);
    }
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
package dev.langchain4j.rag.content.retriever.neo4j;
2+
3+
import dev.langchain4j.internal.ValidationUtils;
4+
import java.util.List;
5+
import java.util.Map;
6+
import java.util.stream.Collectors;
7+
import org.neo4j.driver.Driver;
8+
import org.neo4j.driver.Record;
9+
import org.neo4j.driver.Session;
10+
import org.neo4j.driver.Value;
11+
import org.neo4j.driver.exceptions.ClientException;
12+
import org.neo4j.driver.summary.ResultSummary;
13+
14+
public class Neo4jGraph implements AutoCloseable {
15+
16+
public static class Builder {
17+
private Driver driver;
18+
19+
/**
20+
* @param driver the {@link Driver} (required)
21+
*/
22+
Builder driver(Driver driver) {
23+
this.driver = driver;
24+
return this;
25+
}
26+
27+
Neo4jGraph build() {
28+
return new Neo4jGraph(driver);
29+
}
30+
}
31+
32+
private static final String NODE_PROPERTIES_QUERY =
33+
"""
34+
CALL apoc.meta.data()
35+
YIELD label, other, elementType, type, property
36+
WHERE NOT type = "RELATIONSHIP" AND elementType = "node"
37+
WITH label AS nodeLabels, collect({property:property, type:type}) AS properties
38+
RETURN {labels: nodeLabels, properties: properties} AS output
39+
""";
40+
41+
private static final String REL_PROPERTIES_QUERY =
42+
"""
43+
CALL apoc.meta.data()
44+
YIELD label, other, elementType, type, property
45+
WHERE NOT type = "RELATIONSHIP" AND elementType = "relationship"
46+
WITH label AS nodeLabels, collect({property:property, type:type}) AS properties
47+
RETURN {type: nodeLabels, properties: properties} AS output
48+
""";
49+
50+
private static final String RELATIONSHIPS_QUERY =
51+
"""
52+
CALL apoc.meta.data()
53+
YIELD label, other, elementType, type, property
54+
WHERE type = "RELATIONSHIP" AND elementType = "node"
55+
UNWIND other AS other_node
56+
RETURN {start: label, type: property, end: toString(other_node)} AS output
57+
""";
58+
59+
private final Driver driver;
60+
61+
private String schema;
62+
63+
public Neo4jGraph(final Driver driver) {
64+
65+
this.driver = ValidationUtils.ensureNotNull(driver, "driver");
66+
this.driver.verifyConnectivity();
67+
try {
68+
refreshSchema();
69+
} catch (ClientException e) {
70+
if ("Neo.ClientError.Procedure.ProcedureNotFound".equals(e.code())) {
71+
throw new Neo4jException("Please ensure the APOC plugin is installed in Neo4j", e);
72+
}
73+
throw e;
74+
}
75+
}
76+
77+
public String getSchema() {
78+
return schema;
79+
}
80+
81+
static Builder builder() {
82+
return new Builder();
83+
}
84+
85+
public ResultSummary executeWrite(String queryString) {
86+
87+
try (Session session = this.driver.session()) {
88+
return session.executeWrite(tx -> tx.run(queryString).consume());
89+
} catch (ClientException e) {
90+
throw new Neo4jException("Error executing query: " + queryString, e);
91+
}
92+
}
93+
94+
public List<Record> executeRead(String queryString) {
95+
96+
return this.driver.executableQuery(queryString).execute().records();
97+
}
98+
99+
public void refreshSchema() {
100+
101+
List<String> nodeProperties = formatNodeProperties(executeRead(NODE_PROPERTIES_QUERY));
102+
List<String> relationshipProperties = formatRelationshipProperties(executeRead(REL_PROPERTIES_QUERY));
103+
List<String> relationships = formatRelationships(executeRead(RELATIONSHIPS_QUERY));
104+
105+
this.schema = "Node properties are the following:\n" + String.join("\n", nodeProperties)
106+
+ "\n\n" + "Relationship properties are the following:\n"
107+
+ String.join("\n", relationshipProperties)
108+
+ "\n\n" + "The relationships are the following:\n"
109+
+ String.join("\n", relationships);
110+
}
111+
112+
private List<String> formatNodeProperties(List<Record> records) {
113+
114+
return records.stream()
115+
.map(this::getOutput)
116+
.map(r -> String.format(
117+
"%s %s",
118+
r.asMap().get("labels"), formatMap(r.get("properties").asList(Value::asMap))))
119+
.toList();
120+
}
121+
122+
private List<String> formatRelationshipProperties(List<Record> records) {
123+
124+
return records.stream()
125+
.map(this::getOutput)
126+
.map(r -> String.format(
127+
"%s %s", r.get("type"), formatMap(r.get("properties").asList(Value::asMap))))
128+
.toList();
129+
}
130+
131+
private List<String> formatRelationships(List<Record> records) {
132+
133+
return records.stream()
134+
.map(r -> getOutput(r).asMap())
135+
.map(r -> String.format("(:%s)-[:%s]->(:%s)", r.get("start"), r.get("type"), r.get("end")))
136+
.toList();
137+
}
138+
139+
private Value getOutput(Record record) {
140+
141+
return record.get("output");
142+
}
143+
144+
private String formatMap(List<Map<String, Object>> properties) {
145+
146+
return properties.stream()
147+
.map(prop -> prop.get("property") + ":" + prop.get("type"))
148+
.collect(Collectors.joining(", ", "{", "}"));
149+
}
150+
151+
@Override
152+
public void close() {
153+
154+
this.driver.close();
155+
}
156+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
package dev.langchain4j.rag.content.retriever.neo4j;
2+
3+
import static dev.langchain4j.internal.Utils.getOrDefault;
4+
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
5+
6+
import dev.langchain4j.model.chat.ChatLanguageModel;
7+
import dev.langchain4j.model.input.Prompt;
8+
import dev.langchain4j.model.input.PromptTemplate;
9+
import dev.langchain4j.rag.content.Content;
10+
import dev.langchain4j.rag.content.retriever.ContentRetriever;
11+
import dev.langchain4j.rag.query.Query;
12+
import java.util.List;
13+
import java.util.Map;
14+
import java.util.regex.Matcher;
15+
import java.util.regex.Pattern;
16+
import org.neo4j.driver.Record;
17+
import org.neo4j.driver.types.Type;
18+
import org.neo4j.driver.types.TypeSystem;
19+
20+
public class Neo4jText2CypherRetriever implements ContentRetriever {
21+
22+
public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from(
23+
"""
24+
Based on the Neo4j graph schema below, write a Cypher query that would answer the user's question:
25+
{{schema}}
26+
27+
Question: {{question}}
28+
Cypher query:
29+
""");
30+
31+
private static final Pattern BACKTICKS_PATTERN = Pattern.compile("```(.*?)```", Pattern.MULTILINE | Pattern.DOTALL);
32+
private static final Type NODE = TypeSystem.getDefault().NODE();
33+
private static final Type RELATIONSHIP = TypeSystem.getDefault().RELATIONSHIP();
34+
private static final Type PATH = TypeSystem.getDefault().PATH();
35+
36+
private final Neo4jGraph graph;
37+
38+
private final ChatLanguageModel chatLanguageModel;
39+
40+
private final PromptTemplate promptTemplate;
41+
42+
public Neo4jText2CypherRetriever(
43+
Neo4jGraph graph, ChatLanguageModel chatLanguageModel, PromptTemplate promptTemplate) {
44+
45+
this.graph = ensureNotNull(graph, "graph");
46+
this.chatLanguageModel = ensureNotNull(chatLanguageModel, "chatLanguageModel");
47+
this.promptTemplate = getOrDefault(promptTemplate, DEFAULT_PROMPT_TEMPLATE);
48+
}
49+
50+
public static Builder builder() {
51+
return new Builder();
52+
}
53+
54+
/*
55+
Getter methods
56+
*/
57+
public Neo4jGraph getGraph() {
58+
return graph;
59+
}
60+
61+
public ChatLanguageModel getChatLanguageModel() {
62+
return chatLanguageModel;
63+
}
64+
65+
public PromptTemplate getPromptTemplate() {
66+
return promptTemplate;
67+
}
68+
69+
@Override
70+
public List<Content> retrieve(Query query) {
71+
72+
String question = query.text();
73+
String schema = graph.getSchema();
74+
String cypherQuery = generateCypherQuery(schema, question);
75+
List<String> response = executeQuery(cypherQuery);
76+
return response.stream().map(Content::from).toList();
77+
}
78+
79+
private String generateCypherQuery(String schema, String question) {
80+
81+
Prompt cypherPrompt = promptTemplate.apply(Map.of("schema", schema, "question", question));
82+
String cypherQuery = chatLanguageModel.chat(cypherPrompt.text());
83+
Matcher matcher = BACKTICKS_PATTERN.matcher(cypherQuery);
84+
if (matcher.find()) {
85+
return matcher.group(1);
86+
}
87+
return cypherQuery;
88+
}
89+
90+
private List<String> executeQuery(String cypherQuery) {
91+
92+
List<Record> records = graph.executeRead(cypherQuery);
93+
return records.stream()
94+
.flatMap(r -> r.values().stream())
95+
.map(value -> {
96+
final boolean isEntity =
97+
NODE.isTypeOf(value) || RELATIONSHIP.isTypeOf(value) || PATH.isTypeOf(value);
98+
if (isEntity) {
99+
return value.asMap().toString();
100+
}
101+
return value.toString();
102+
})
103+
.toList();
104+
}
105+
106+
public static class Builder<T extends Builder<T>> {
107+
108+
protected Neo4jGraph graph;
109+
protected ChatLanguageModel chatLanguageModel;
110+
protected PromptTemplate promptTemplate;
111+
112+
/**
113+
* @param graph the {@link Neo4jGraph} (required)
114+
*/
115+
public T graph(Neo4jGraph graph) {
116+
this.graph = graph;
117+
return self();
118+
}
119+
120+
/**
121+
* @param chatLanguageModel the {@link ChatLanguageModel} (required)
122+
*/
123+
public T chatLanguageModel(ChatLanguageModel chatLanguageModel) {
124+
this.chatLanguageModel = chatLanguageModel;
125+
return self();
126+
}
127+
128+
/**
129+
* @param promptTemplate the {@link PromptTemplate} (optional, default is {@link Neo4jText2CypherRetriever#DEFAULT_PROMPT_TEMPLATE})
130+
*/
131+
public T promptTemplate(PromptTemplate promptTemplate) {
132+
this.promptTemplate = promptTemplate;
133+
return self();
134+
}
135+
136+
protected T self() {
137+
return (T) this;
138+
}
139+
140+
Neo4jText2CypherRetriever build() {
141+
return new Neo4jText2CypherRetriever(graph, chatLanguageModel, promptTemplate);
142+
}
143+
}
144+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package dev.langchain4j.rag.content.retriever.neo4j;
2+
3+
import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_4_O_MINI;
4+
import static org.assertj.core.api.Assertions.assertThat;
5+
6+
import dev.langchain4j.model.chat.ChatLanguageModel;
7+
import dev.langchain4j.model.openai.OpenAiChatModel;
8+
import dev.langchain4j.rag.content.Content;
9+
import dev.langchain4j.rag.query.Query;
10+
import java.util.List;
11+
import org.junit.jupiter.api.Test;
12+
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
13+
14+
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
15+
class Neo4JText2CypherRetrieverIT extends Neo4jText2CypherRetrieverBaseTest {
16+
17+
@Test
18+
void shouldRetrieveContentWhenQueryIsValidAndOpenAiChatModelIsUsed() {
19+
20+
// With
21+
ChatLanguageModel openAiChatModel = OpenAiChatModel.builder()
22+
.baseUrl(System.getenv("OPENAI_BASE_URL"))
23+
.apiKey(System.getenv("OPENAI_API_KEY"))
24+
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
25+
.modelName(GPT_4_O_MINI)
26+
.logRequests(true)
27+
.logResponses(true)
28+
.build();
29+
30+
Neo4jText2CypherRetriever neo4jContentRetriever = Neo4jText2CypherRetriever.builder()
31+
.graph(graph)
32+
.chatLanguageModel(openAiChatModel)
33+
.build();
34+
35+
// Given
36+
Query query = new Query("Who is the author of the book 'Dune'?");
37+
38+
// When
39+
List<Content> contents = neo4jContentRetriever.retrieve(query);
40+
41+
// Then
42+
assertThat(contents).hasSize(1);
43+
}
44+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
package dev.langchain4j.rag.content.retriever.neo4j;
2+
3+
import static org.assertj.core.api.Assertions.assertThat;
4+
import static org.mockito.ArgumentMatchers.anyString;
5+
import static org.mockito.Mockito.when;
6+
7+
import dev.langchain4j.model.chat.ChatLanguageModel;
8+
import dev.langchain4j.rag.content.Content;
9+
import dev.langchain4j.rag.query.Query;
10+
import java.util.List;
11+
import org.junit.jupiter.api.BeforeEach;
12+
import org.junit.jupiter.api.Test;
13+
import org.junit.jupiter.api.extension.ExtendWith;
14+
import org.mockito.Mock;
15+
import org.mockito.junit.jupiter.MockitoExtension;
16+
17+
@ExtendWith(MockitoExtension.class)
18+
public class Neo4JText2CypherRetrieverTest extends Neo4jText2CypherRetrieverBaseTest {
19+
20+
private Neo4jText2CypherRetriever retriever;
21+
22+
@Mock
23+
private ChatLanguageModel chatLanguageModel;
24+
25+
@BeforeEach
26+
void beforeEach() {
27+
super.beforeEach();
28+
29+
retriever = Neo4jText2CypherRetriever.builder()
30+
.graph(graph)
31+
.chatLanguageModel(chatLanguageModel)
32+
.build();
33+
}
34+
35+
@Test
36+
void shouldRetrieveContentWhenQueryIsValid() {
37+
// Given
38+
Query query = new Query("Who is the author of the book 'Dune'?");
39+
when(chatLanguageModel.chat(anyString()))
40+
.thenReturn("MATCH(book:Book {title: 'Dune'})<-[:WROTE]-(author:Person) RETURN author.name AS output");
41+
42+
// When
43+
List<Content> contents = retriever.retrieve(query);
44+
45+
// Then
46+
assertThat(contents).hasSize(1);
47+
}
48+
49+
@Test
50+
void shouldRetrieveContentWhenQueryIsValidWithDeprecatedClass() {
51+
// Given
52+
Query query = new Query("Who is the author of the book 'Dune'?");
53+
when(chatLanguageModel.chat(anyString()))
54+
.thenReturn("MATCH(book:Book {title: 'Dune'})<-[:WROTE]-(author:Person) RETURN author.name AS output");
55+
56+
Neo4jContentRetriever retriever = Neo4jContentRetriever.builder()
57+
.graph(graph)
58+
.chatLanguageModel(chatLanguageModel)
59+
.build();
60+
61+
// When
62+
List<Content> contents = retriever.retrieve(query);
63+
64+
// Then
65+
assertThat(contents).hasSize(1);
66+
}
67+
68+
@Test
69+
void shouldRetrieveContentWhenQueryIsValidAndResponseHasBackticks() {
70+
// Given
71+
Query query = new Query("Who is the author of the book 'Dune'?");
72+
when(chatLanguageModel.chat(anyString()))
73+
.thenReturn(
74+
"```MATCH(book:Book {title: 'Dune'})<-[:WROTE]-(author:Person) RETURN author.name AS output```");
75+
76+
// When
77+
List<Content> contents = retriever.retrieve(query);
78+
79+
// Then
80+
assertThat(contents).hasSize(1);
81+
}
82+
83+
@Test
84+
void shouldReturnEmptyListWhenQueryIsInvalid() {
85+
// Given
86+
Query query = new Query("Who is the author of the movie 'Dune'?");
87+
when(chatLanguageModel.chat(anyString()))
88+
.thenReturn(
89+
"MATCH(movie:Movie {title: 'Dune'})<-[:WROTE]-(author:Person) RETURN author.name AS output");
90+
91+
// When
92+
List<Content> contents = retriever.retrieve(query);
93+
94+
// Then
95+
assertThat(contents).isEmpty();
96+
}
97+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package dev.langchain4j.rag.content.retriever.neo4j;
2+
3+
import org.junit.jupiter.api.AfterAll;
4+
import org.junit.jupiter.api.AfterEach;
5+
import org.junit.jupiter.api.BeforeAll;
6+
import org.junit.jupiter.api.BeforeEach;
7+
import org.neo4j.driver.AuthTokens;
8+
import org.neo4j.driver.Driver;
9+
import org.neo4j.driver.GraphDatabase;
10+
import org.neo4j.driver.Session;
11+
import org.testcontainers.containers.Neo4jContainer;
12+
import org.testcontainers.junit.jupiter.Container;
13+
14+
public class Neo4jText2CypherRetrieverBaseTest {
15+
16+
protected static final String NEO4J_VERSION = System.getProperty("neo4jVersion", "5.26");
17+
18+
protected Driver driver;
19+
protected Neo4jGraph graph;
20+
21+
@Container
22+
protected static final Neo4jContainer<?> neo4jContainer = new Neo4jContainer<>("neo4j:" + NEO4J_VERSION)
23+
.withoutAuthentication()
24+
.withPlugins("apoc");
25+
26+
@BeforeAll
27+
static void beforeAll() {
28+
neo4jContainer.start();
29+
}
30+
31+
@AfterAll
32+
static void afterAll() {
33+
neo4jContainer.stop();
34+
}
35+
36+
@BeforeEach
37+
void beforeEach() {
38+
39+
driver = GraphDatabase.driver(neo4jContainer.getBoltUrl(), AuthTokens.none());
40+
41+
try (Session session = driver.session()) {
42+
session.run("CREATE (book:Book {title: 'Dune'})<-[:WROTE]-(author:Person {name: 'Frank Herbert'})");
43+
}
44+
45+
graph = Neo4jGraph.builder().driver(driver).build();
46+
}
47+
48+
@AfterEach
49+
void afterEach() {
50+
try (Session session = driver.session()) {
51+
session.run("MATCH (n) DETACH DELETE n");
52+
}
53+
graph.close();
54+
driver.close();
55+
}
56+
}

‎langchain4j-community-bom/pom.xml

+5
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,11 @@
9191
<artifactId>langchain4j-community-lucene</artifactId>
9292
<version>${project.version}</version>
9393
</dependency>
94+
<dependency>
95+
<groupId>dev.langchain4j</groupId>
96+
<artifactId>langchain4j-community-neo4j-retriever</artifactId>
97+
<version>${project.version}</version>
98+
</dependency>
9499

95100
<!-- web search engines -->
96101
<dependency>

‎pom.xml

+1
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565

6666
<!-- Integration of content retrievers -->
6767
<module>content-retrievers/langchain4j-community-lucene</module>
68+
<module>content-retrievers/langchain4j-community-neo4j-retriever</module>
6869

6970
<!-- Integration of web search engine -->
7071
<module>web-search-engines/langchain4j-community-web-search-engine-searxng</module>

0 commit comments

Comments
 (0)
Please sign in to comment.