Skip to content

Commit b3ae46a

Browse files
authored
Add string input support for revised Neptune chains (#329)
Updated `create_neptune_opencypher_qa_chain` and `create_neptune_sparql_qa_chain` to accept base string type queries on invoke, in addition to the current dict format. This restores consistency with the input format of the older `langchain-community` Neptune chains.
1 parent 38c28fa commit b3ae46a

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

libs/aws/langchain_aws/chains/graph_qa/neptune_cypher.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import re
4-
from typing import Any, Optional
4+
from typing import Any, Optional, Union
55

66
from langchain_core.language_models import BaseLanguageModel
77
from langchain_core.prompts.base import BasePromptTemplate
@@ -90,7 +90,7 @@ def create_neptune_opencypher_qa_chain(
9090
return_direct: bool = False,
9191
extra_instructions: Optional[str] = None,
9292
allow_dangerous_requests: bool = False,
93-
) -> Runnable[dict[str, Any], dict]:
93+
) -> Runnable:
9494
"""Chain for question-answering against a Neptune graph
9595
by generating openCypher statements.
9696
@@ -133,6 +133,11 @@ def create_neptune_opencypher_qa_chain(
133133
_cypher_prompt = cypher_prompt or get_prompt(llm)
134134
cypher_generation_chain = _cypher_prompt | llm
135135

136+
def normalize_input(raw_input: Union[str, dict]) -> dict:
137+
if isinstance(raw_input, str):
138+
return {"query": raw_input}
139+
return raw_input
140+
136141
def execute_graph_query(cypher_query: str) -> dict:
137142
return graph.query(cypher_query)
138143

@@ -164,7 +169,8 @@ def format_response(inputs: dict) -> dict:
164169
return final_response
165170

166171
chain_result = (
167-
RunnablePassthrough.assign(cypher_generation_inputs=get_cypher_inputs)
172+
normalize_input
173+
| RunnablePassthrough.assign(cypher_generation_inputs=get_cypher_inputs)
168174
| {
169175
"query": lambda x: x["query"],
170176
"cypher": (lambda x: x["cypher_generation_inputs"])

libs/aws/langchain_aws/chains/graph_qa/neptune_sparql.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from __future__ import annotations
66

7-
from typing import Any, Optional
7+
from typing import Any, Optional, Union
88

99
from langchain_core.language_models import BaseLanguageModel
1010
from langchain_core.prompts.base import BasePromptTemplate
@@ -64,7 +64,7 @@ def create_neptune_sparql_qa_chain(
6464
extra_instructions: Optional[str] = None,
6565
allow_dangerous_requests: bool = False,
6666
examples: Optional[str] = None,
67-
) -> Runnable[dict[str, Any], dict]:
67+
) -> Runnable[Any, dict]:
6868
"""Chain for question-answering against a Neptune graph
6969
by generating SPARQL statements.
7070
@@ -106,6 +106,11 @@ def create_neptune_sparql_qa_chain(
106106
_sparql_prompt = sparql_prompt or get_prompt(examples)
107107
sparql_generation_chain = _sparql_prompt | llm
108108

109+
def normalize_input(raw_input: Union[str, dict]) -> dict:
110+
if isinstance(raw_input, str):
111+
return {"query": raw_input}
112+
return raw_input
113+
109114
def execute_graph_query(sparql_query: str) -> dict:
110115
return graph.query(sparql_query)
111116

@@ -137,7 +142,8 @@ def format_response(inputs: dict) -> dict:
137142
return final_response
138143

139144
chain_result = (
140-
RunnablePassthrough.assign(sparql_generation_inputs=get_sparql_inputs)
145+
normalize_input
146+
| RunnablePassthrough.assign(sparql_generation_inputs=get_sparql_inputs)
141147
| {
142148
"query": lambda x: x["query"],
143149
"sparql": (lambda x: x["sparql_generation_inputs"])

0 commit comments

Comments
 (0)