
Commit c4fa5eb

advanced rag script
The file that creates parent and child nodes from a document, preparing for further multi-hop searches.
1 parent 681344e commit c4fa5eb

File tree

1 file changed (+203, -0)

ingest.py

from typing import List

from langchain.chains.openai_functions import create_structured_output_chain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import WikipediaLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.graphs import Neo4jGraph
from langchain.prompts import ChatPromptTemplate
from langchain.pydantic_v1 import BaseModel, Field
from langchain.text_splitter import TokenTextSplitter
from neo4j.exceptions import ClientError

# Connects to Neo4j using the NEO4J_URI, NEO4J_USERNAME and NEO4J_PASSWORD
# environment variables
graph = Neo4jGraph()

# Load Wikipedia data
all_data = WikipediaLoader(query="Removal_of_Sam_Altman_from_OpenAI").load()

# Embeddings & LLM models
embeddings = OpenAIEmbeddings()
embedding_dimension = 1536  # dimension of OpenAI's text-embedding-ada-002
llm = ChatOpenAI(temperature=0)

# Process all data: large parent chunks carry the context handed to the LLM,
# while small child chunks give precise vector matches
parent_splitter = TokenTextSplitter(chunk_size=512, chunk_overlap=24)
child_splitter = TokenTextSplitter(chunk_size=100, chunk_overlap=24)

# Split every loaded page into parent chunks; flattening them into a single
# list keeps the parent ids unique across documents
parent_documents = []
for document in all_data:
    parent_documents.extend(parent_splitter.split_documents([document]))

# Ingest Parent-Child node pairs
for i, parent in enumerate(parent_documents):
    child_documents = child_splitter.split_documents([parent])
    params = {
        "parent_text": parent.page_content,
        "parent_id": i,
        "parent_embedding": embeddings.embed_query(parent.page_content),
        "children": [
            {
                "text": c.page_content,
                "id": f"{i}-{ic}",
                "embedding": embeddings.embed_query(c.page_content),
            }
            for ic, c in enumerate(child_documents)
        ],
    }
    # Ingest the parent, its embedding, and all of its children in one query
    graph.query(
        """
        MERGE (p:Parent {id: $parent_id})
        SET p.text = $parent_text
        WITH p
        CALL db.create.setVectorProperty(p, 'embedding', $parent_embedding)
        YIELD node
        WITH p
        UNWIND $children AS child
        MERGE (c:Child {id: child.id})
        SET c.text = child.text
        MERGE (c)<-[:HAS_CHILD]-(p)
        WITH c, child
        CALL db.create.setVectorProperty(c, 'embedding', child.embedding)
        YIELD node
        RETURN count(*)
        """,
        params,
    )
# Create vector index for children
try:
    graph.query(
        "CALL db.index.vector.createNodeIndex('parent_document', "
        "'Child', 'embedding', $dimension, 'cosine')",
        {"dimension": embedding_dimension},
    )
except ClientError:  # already exists
    pass
# Create vector index for parents
try:
    graph.query(
        "CALL db.index.vector.createNodeIndex('typical_rag', "
        "'Parent', 'embedding', $dimension, 'cosine')",
        {"dimension": embedding_dimension},
    )
except ClientError:  # already exists
    pass
# Ingest hypothetical questions


class Questions(BaseModel):
    """Generating hypothetical questions about text."""

    questions: List[str] = Field(
        ...,
        description=(
            "Generated hypothetical questions based on "
            "the information from the text"
        ),
    )


questions_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            (
                "You are generating hypothetical questions based on the information "
                "found in the text. Make sure to provide full context in the generated "
                "questions."
            ),
        ),
        (
            "human",
            (
                "Use the given format to generate hypothetical questions from the "
                "following input: {input}"
            ),
        ),
    ]
)

question_chain = create_structured_output_chain(Questions, llm, questions_prompt)

for i, parent in enumerate(parent_documents):
    questions = question_chain.run(parent.page_content).questions
    params = {
        "parent_id": i,
        "questions": [
            {"text": q, "id": f"{i}-{iq}", "embedding": embeddings.embed_query(q)}
            for iq, q in enumerate(questions)
            if q
        ],
    }
    # Link every generated question to its parent chunk
    graph.query(
        """
        MERGE (p:Parent {id: $parent_id})
        WITH p
        UNWIND $questions AS question
        CREATE (q:Question {id: question.id})
        SET q.text = question.text
        MERGE (q)<-[:HAS_QUESTION]-(p)
        WITH q, question
        CALL db.create.setVectorProperty(q, 'embedding', question.embedding)
        YIELD node
        RETURN count(*)
        """,
        params,
    )
# Create vector index for questions
try:
    graph.query(
        "CALL db.index.vector.createNodeIndex('hypothetical_questions', "
        "'Question', 'embedding', $dimension, 'cosine')",
        {"dimension": embedding_dimension},
    )
except ClientError:  # already exists
    pass

# Ingest summaries

summary_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            (
                "You are generating concise and accurate summaries based on the "
                "information found in the text."
            ),
        ),
        (
            "human",
            "Generate a summary of the following input: {question}\nSummary:",
        ),
    ]
)

summary_chain = summary_prompt | llm

for i, parent in enumerate(parent_documents):
    summary = summary_chain.invoke({"question": parent.page_content}).content
    params = {
        "parent_id": i,
        "summary": summary,
        "embedding": embeddings.embed_query(summary),
    }
    # Attach a summary node with its embedding to each parent chunk
    graph.query(
        """
        MERGE (p:Parent {id: $parent_id})
        MERGE (p)-[:HAS_SUMMARY]->(s:Summary)
        SET s.text = $summary
        WITH s
        CALL db.create.setVectorProperty(s, 'embedding', $embedding)
        YIELD node
        RETURN count(*)
        """,
        params,
    )
# Create vector index for summaries
try:
    graph.query(
        "CALL db.index.vector.createNodeIndex('summary', "
        "'Summary', 'embedding', $dimension, 'cosine')",
        {"dimension": embedding_dimension},
    )
except ClientError:  # already exists
    pass
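This commit only covers ingestion; retrieval happens elsewhere. As a minimal sketch of how the parent_document index created above could serve the multi-hop searches mentioned in the commit message, LangChain's Neo4jVector accepts a retrieval_query that is appended after the vector search, with node and score bound for every hit. Here it hops from each matched Child chunk to its Parent, so the small chunk is matched but the wider context is returned. The query text, k value, and example question are illustrative assumptions, not part of this commit.

from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Neo4jVector

# `node` is a matched Child and `score` its similarity; hop to the Parent
# and deduplicate parents that several children point at
parent_retrieval_query = """
MATCH (node)<-[:HAS_CHILD]-(parent)
WITH parent, max(score) AS score
RETURN parent.text AS text, score, {} AS metadata
"""

parent_store = Neo4jVector.from_existing_index(
    OpenAIEmbeddings(),
    index_name="parent_document",
    retrieval_query=parent_retrieval_query,
)

docs = parent_store.similarity_search("Why was Sam Altman removed as CEO?", k=4)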

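Continuing the sketch, the hypothetical_questions and summary indexes support the same pattern: the user question is matched against the generated questions or summaries, and the parent chunk behind the best match is returned. Only the index names and the HAS_QUESTION / HAS_SUMMARY relationships come from the script above; the retrieval queries themselves are hedged assumptions.

# Match the user question against generated questions, return the parent
question_store = Neo4jVector.from_existing_index(
    OpenAIEmbeddings(),
    index_name="hypothetical_questions",
    retrieval_query="""
MATCH (node)<-[:HAS_QUESTION]-(parent)
WITH parent, max(score) AS score
RETURN parent.text AS text, score, {} AS metadata
""",
)

# Match against summaries, return the summarized parent chunk
summary_store = Neo4jVector.from_existing_index(
    OpenAIEmbeddings(),
    index_name="summary",
    retrieval_query="""
MATCH (node)<-[:HAS_SUMMARY]-(parent)
WITH parent, max(score) AS score
RETURN parent.text AS text, score, {} AS metadata
""",
)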