|
4 | 4 |
|
5 | 5 | from __future__ import annotations
|
6 | 6 |
|
7 |
| -from typing import Any, Optional |
| 7 | +from typing import Any, Optional, Union |
8 | 8 |
|
9 | 9 | from langchain_core.language_models import BaseLanguageModel
|
10 | 10 | from langchain_core.prompts.base import BasePromptTemplate
|
@@ -64,7 +64,7 @@ def create_neptune_sparql_qa_chain(
|
64 | 64 | extra_instructions: Optional[str] = None,
|
65 | 65 | allow_dangerous_requests: bool = False,
|
66 | 66 | examples: Optional[str] = None,
|
67 |
| -) -> Runnable[dict[str, Any], dict]: |
| 67 | +) -> Runnable[Any, dict]: |
68 | 68 | """Chain for question-answering against a Neptune graph
|
69 | 69 | by generating SPARQL statements.
|
70 | 70 |
|
@@ -106,6 +106,11 @@ def create_neptune_sparql_qa_chain(
|
106 | 106 | _sparql_prompt = sparql_prompt or get_prompt(examples)
|
107 | 107 | sparql_generation_chain = _sparql_prompt | llm
|
108 | 108 |
|
| 109 | + def normalize_input(raw_input: Union[str, dict]) -> dict: |
| 110 | + if isinstance(raw_input, str): |
| 111 | + return {"query": raw_input} |
| 112 | + return raw_input |
| 113 | + |
109 | 114 | def execute_graph_query(sparql_query: str) -> dict:
|
110 | 115 | return graph.query(sparql_query)
|
111 | 116 |
|
@@ -137,7 +142,8 @@ def format_response(inputs: dict) -> dict:
|
137 | 142 | return final_response
|
138 | 143 |
|
139 | 144 | chain_result = (
|
140 |
| - RunnablePassthrough.assign(sparql_generation_inputs=get_sparql_inputs) |
| 145 | + normalize_input |
| 146 | + | RunnablePassthrough.assign(sparql_generation_inputs=get_sparql_inputs) |
141 | 147 | | {
|
142 | 148 | "query": lambda x: x["query"],
|
143 | 149 | "sparql": (lambda x: x["sparql_generation_inputs"])
|
|
0 commit comments