Skip to content

Commit ff8e656

Browse files
committed
Migrating neo4j-content-retriever
1 parent 027f4ac commit ff8e656

File tree

8 files changed

+537
-0
lines changed

8 files changed

+537
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
3+
<modelVersion>4.0.0</modelVersion>
4+
<parent>
5+
<groupId>dev.langchain4j</groupId>
6+
<artifactId>langchain4j-community</artifactId>
7+
<version>1.0.0-beta3-SNAPSHOT</version>
8+
<relativePath>../../pom.xml</relativePath>
9+
</parent>
10+
11+
<artifactId>langchain4j-community-neo4j</artifactId>
12+
<packaging>jar</packaging>
13+
<name>LangChain4j :: Community :: Content Retriever :: Neo4j</name>
14+
15+
<dependencies>
16+
17+
<dependency>
18+
<groupId>dev.langchain4j</groupId>
19+
<artifactId>langchain4j-core</artifactId>
20+
<version>${project.version}</version>
21+
</dependency>
22+
23+
<dependency>
24+
<groupId>org.slf4j</groupId>
25+
<artifactId>slf4j-api</artifactId>
26+
</dependency>
27+
28+
<dependency>
29+
<groupId>org.neo4j.driver</groupId>
30+
<artifactId>neo4j-java-driver</artifactId>
31+
<version>5.26.0</version>
32+
</dependency>
33+
34+
<!-- test dependencies -->
35+
<dependency>
36+
<groupId>org.testcontainers</groupId>
37+
<artifactId>junit-jupiter</artifactId>
38+
<scope>test</scope>
39+
</dependency>
40+
41+
<dependency>
42+
<groupId>org.testcontainers</groupId>
43+
<artifactId>neo4j</artifactId>
44+
<scope>test</scope>
45+
</dependency>
46+
47+
<dependency>
48+
<groupId>org.junit.jupiter</groupId>
49+
<artifactId>junit-jupiter</artifactId>
50+
<scope>test</scope>
51+
</dependency>
52+
53+
<dependency>
54+
<groupId>dev.langchain4j</groupId>
55+
<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
56+
<scope>test</scope>
57+
</dependency>
58+
59+
<dependency>
60+
<groupId>org.assertj</groupId>
61+
<artifactId>assertj-core</artifactId>
62+
<scope>test</scope>
63+
</dependency>
64+
65+
<dependency>
66+
<groupId>dev.langchain4j</groupId>
67+
<artifactId>langchain4j-open-ai</artifactId>
68+
<version>${project.version}</version>
69+
<scope>test</scope>
70+
</dependency>
71+
72+
</dependencies>
73+
74+
</project>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
package dev.langchain4j.rag.content.retriever.neo4j;
2+
3+
import static dev.langchain4j.internal.Utils.getOrDefault;
4+
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
5+
6+
import dev.langchain4j.model.chat.ChatLanguageModel;
7+
import dev.langchain4j.model.input.Prompt;
8+
import dev.langchain4j.model.input.PromptTemplate;
9+
import dev.langchain4j.rag.content.Content;
10+
import dev.langchain4j.rag.content.retriever.ContentRetriever;
11+
import dev.langchain4j.rag.query.Query;
12+
import java.util.List;
13+
import java.util.Map;
14+
import java.util.regex.Matcher;
15+
import java.util.regex.Pattern;
16+
import org.neo4j.driver.Record;
17+
import org.neo4j.driver.types.Type;
18+
import org.neo4j.driver.types.TypeSystem;
19+
20+
public class Neo4jContentRetriever implements ContentRetriever {
21+
22+
public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from(
23+
"""
24+
Based on the Neo4j graph schema below, write a Cypher query that would answer the user's question:
25+
{{schema}}
26+
27+
Question: {{question}}
28+
Cypher query:
29+
""");
30+
31+
private static final Pattern BACKTICKS_PATTERN = Pattern.compile("```(.*?)```", Pattern.MULTILINE | Pattern.DOTALL);
32+
private static final Type NODE = TypeSystem.getDefault().NODE();
33+
34+
private final Neo4jGraph graph;
35+
36+
private final ChatLanguageModel chatLanguageModel;
37+
38+
private final PromptTemplate promptTemplate;
39+
40+
public Neo4jContentRetriever(Neo4jGraph graph, ChatLanguageModel chatLanguageModel, PromptTemplate promptTemplate) {
41+
42+
this.graph = ensureNotNull(graph, "graph");
43+
this.chatLanguageModel = ensureNotNull(chatLanguageModel, "chatLanguageModel");
44+
this.promptTemplate = getOrDefault(promptTemplate, DEFAULT_PROMPT_TEMPLATE);
45+
}
46+
47+
public static Neo4jContentRetrieverBuilder builder() {
48+
return new Neo4jContentRetrieverBuilder();
49+
}
50+
51+
/*
52+
Getter methods
53+
*/
54+
public Neo4jGraph getGraph() {
55+
return graph;
56+
}
57+
58+
public ChatLanguageModel getChatLanguageModel() {
59+
return chatLanguageModel;
60+
}
61+
62+
public PromptTemplate getPromptTemplate() {
63+
return promptTemplate;
64+
}
65+
66+
@Override
67+
public List<Content> retrieve(Query query) {
68+
69+
String question = query.text();
70+
String schema = graph.getSchema();
71+
String cypherQuery = generateCypherQuery(schema, question);
72+
List<String> response = executeQuery(cypherQuery);
73+
return response.stream().map(Content::from).toList();
74+
}
75+
76+
private String generateCypherQuery(String schema, String question) {
77+
78+
Prompt cypherPrompt = promptTemplate.apply(Map.of("schema", schema, "question", question));
79+
String cypherQuery = chatLanguageModel.chat(cypherPrompt.text());
80+
Matcher matcher = BACKTICKS_PATTERN.matcher(cypherQuery);
81+
if (matcher.find()) {
82+
return matcher.group(1);
83+
}
84+
return cypherQuery;
85+
}
86+
87+
private List<String> executeQuery(String cypherQuery) {
88+
89+
List<Record> records = graph.executeRead(cypherQuery);
90+
return records.stream()
91+
.flatMap(r -> r.values().stream())
92+
.map(value -> NODE.isTypeOf(value) ? value.asMap().toString() : value.toString())
93+
.toList();
94+
}
95+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package dev.langchain4j.rag.content.retriever.neo4j;
2+
3+
import dev.langchain4j.model.chat.ChatLanguageModel;
4+
import dev.langchain4j.model.input.PromptTemplate;
5+
6+
public class Neo4jContentRetrieverBuilder {
7+
private Neo4jGraph graph;
8+
private ChatLanguageModel chatLanguageModel;
9+
private PromptTemplate promptTemplate;
10+
11+
/**
12+
* @param graph the {@link Neo4jGraph} (required)
13+
*/
14+
public Neo4jContentRetrieverBuilder graph(Neo4jGraph graph) {
15+
this.graph = graph;
16+
return this;
17+
}
18+
19+
/**
20+
* @param chatLanguageModel the {@link ChatLanguageModel} (required)
21+
*/
22+
public Neo4jContentRetrieverBuilder chatLanguageModel(ChatLanguageModel chatLanguageModel) {
23+
this.chatLanguageModel = chatLanguageModel;
24+
return this;
25+
}
26+
27+
/**
28+
* @param promptTemplate the {@link PromptTemplate} (optional, default is {@link Neo4jContentRetriever#DEFAULT_PROMPT_TEMPLATE})
29+
*/
30+
public Neo4jContentRetrieverBuilder promptTemplate(PromptTemplate promptTemplate) {
31+
this.promptTemplate = promptTemplate;
32+
return this;
33+
}
34+
35+
Neo4jContentRetriever build() {
36+
return new Neo4jContentRetriever(graph, chatLanguageModel, promptTemplate);
37+
}
38+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package dev.langchain4j.rag.content.retriever.neo4j;
2+
3+
public class Neo4jException extends RuntimeException {
4+
5+
public Neo4jException(String message, Throwable cause) {
6+
7+
super(message, cause);
8+
}
9+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
package dev.langchain4j.rag.content.retriever.neo4j;
2+
3+
import dev.langchain4j.internal.ValidationUtils;
4+
import java.util.List;
5+
import java.util.Map;
6+
import java.util.stream.Collectors;
7+
import org.neo4j.driver.Driver;
8+
import org.neo4j.driver.Query;
9+
import org.neo4j.driver.Record;
10+
import org.neo4j.driver.Result;
11+
import org.neo4j.driver.Session;
12+
import org.neo4j.driver.Value;
13+
import org.neo4j.driver.exceptions.ClientException;
14+
import org.neo4j.driver.summary.ResultSummary;
15+
16+
public class Neo4jGraph implements AutoCloseable {
17+
18+
public static class Builder {
19+
private Driver driver;
20+
21+
/**
22+
* @param driver the {@link Driver} (required)
23+
*/
24+
Builder driver(Driver driver) {
25+
this.driver = driver;
26+
return this;
27+
}
28+
29+
Neo4jGraph build() {
30+
return new Neo4jGraph(driver);
31+
}
32+
}
33+
34+
private static final String NODE_PROPERTIES_QUERY =
35+
"""
36+
CALL apoc.meta.data()
37+
YIELD label, other, elementType, type, property
38+
WHERE NOT type = "RELATIONSHIP" AND elementType = "node"
39+
WITH label AS nodeLabels, collect({property:property, type:type}) AS properties
40+
RETURN {labels: nodeLabels, properties: properties} AS output
41+
""";
42+
43+
private static final String REL_PROPERTIES_QUERY =
44+
"""
45+
CALL apoc.meta.data()
46+
YIELD label, other, elementType, type, property
47+
WHERE NOT type = "RELATIONSHIP" AND elementType = "relationship"
48+
WITH label AS nodeLabels, collect({property:property, type:type}) AS properties
49+
RETURN {type: nodeLabels, properties: properties} AS output
50+
""";
51+
52+
private static final String RELATIONSHIPS_QUERY =
53+
"""
54+
CALL apoc.meta.data()
55+
YIELD label, other, elementType, type, property
56+
WHERE type = "RELATIONSHIP" AND elementType = "node"
57+
UNWIND other AS other_node
58+
RETURN {start: label, type: property, end: toString(other_node)} AS output
59+
""";
60+
61+
private final Driver driver;
62+
63+
private String schema;
64+
65+
public Neo4jGraph(final Driver driver) {
66+
67+
this.driver = ValidationUtils.ensureNotNull(driver, "driver");
68+
this.driver.verifyConnectivity();
69+
try {
70+
refreshSchema();
71+
} catch (ClientException e) {
72+
if ("Neo.ClientError.Procedure.ProcedureNotFound".equals(e.code())) {
73+
throw new Neo4jException("Please ensure the APOC plugin is installed in Neo4j", e);
74+
}
75+
throw e;
76+
}
77+
}
78+
79+
public String getSchema() {
80+
return schema;
81+
}
82+
83+
static Builder builder() {
84+
return new Builder();
85+
}
86+
87+
public ResultSummary executeWrite(String queryString) {
88+
89+
try (Session session = this.driver.session()) {
90+
return session.executeWrite(tx -> tx.run(queryString).consume());
91+
} catch (ClientException e) {
92+
throw new Neo4jException("Error executing query: " + queryString, e);
93+
}
94+
}
95+
96+
public List<Record> executeRead(String queryString) {
97+
98+
try (Session session = this.driver.session()) {
99+
return session.executeRead(tx -> {
100+
Query query = new Query(queryString);
101+
Result result = tx.run(query);
102+
return result.list();
103+
});
104+
} catch (ClientException e) {
105+
throw new Neo4jException("Error executing query: " + queryString, e);
106+
}
107+
}
108+
109+
public void refreshSchema() {
110+
111+
List<String> nodeProperties = formatNodeProperties(executeRead(NODE_PROPERTIES_QUERY));
112+
List<String> relationshipProperties = formatRelationshipProperties(executeRead(REL_PROPERTIES_QUERY));
113+
List<String> relationships = formatRelationships(executeRead(RELATIONSHIPS_QUERY));
114+
115+
this.schema = "Node properties are the following:\n" + String.join("\n", nodeProperties)
116+
+ "\n\n" + "Relationship properties are the following:\n"
117+
+ String.join("\n", relationshipProperties)
118+
+ "\n\n" + "The relationships are the following:\n"
119+
+ String.join("\n", relationships);
120+
}
121+
122+
private List<String> formatNodeProperties(List<Record> records) {
123+
124+
return records.stream()
125+
.map(this::getOutput)
126+
.map(r -> String.format(
127+
"%s %s",
128+
r.asMap().get("labels"), formatMap(r.get("properties").asList(Value::asMap))))
129+
.toList();
130+
}
131+
132+
private List<String> formatRelationshipProperties(List<Record> records) {
133+
134+
return records.stream()
135+
.map(this::getOutput)
136+
.map(r -> String.format(
137+
"%s %s", r.get("type"), formatMap(r.get("properties").asList(Value::asMap))))
138+
.toList();
139+
}
140+
141+
private List<String> formatRelationships(List<Record> records) {
142+
143+
return records.stream()
144+
.map(r -> getOutput(r).asMap())
145+
.map(r -> String.format("(:%s)-[:%s]->(:%s)", r.get("start"), r.get("type"), r.get("end")))
146+
.toList();
147+
}
148+
149+
private Value getOutput(Record record) {
150+
151+
return record.get("output");
152+
}
153+
154+
private String formatMap(List<Map<String, Object>> properties) {
155+
156+
return properties.stream()
157+
.map(prop -> prop.get("property") + ":" + prop.get("type"))
158+
.collect(Collectors.joining(", ", "{", "}"));
159+
}
160+
161+
@Override
162+
public void close() {
163+
164+
this.driver.close();
165+
}
166+
}

0 commit comments

Comments
 (0)