diff --git a/backend/score.py b/backend/score.py
index ae1634bb6..7b5598081 100644
--- a/backend/score.py
+++ b/backend/score.py
@@ -8,6 +8,7 @@
 import asyncio
 import base64
 from langserve import add_routes
+from langchain_core.messages import HumanMessage, AIMessage
 from langchain_google_vertexai import ChatVertexAI
 from src.api_response import create_api_response
 from src.graphDB_dataAccess import graphDBdataAccess
@@ -375,7 +376,28 @@ async def post_processing(uri=Form(), userName=Form(), password=Form(), database
     finally:
         gc.collect()
-    
+
+@app.post('/retrieve_docs')
+async def retrieve_docs(uri=Form(),model=Form(None),userName=Form(), password=Form(), database=Form(), document_names=Form(None), mode=Form(None), question=Form(None)):
+    if mode == "graph":
+        graph = Neo4jGraph( url=uri,username=userName,password=password,database=database,sanitize = True, refresh_schema=True)
+    else:
+        graph = create_graph_database_connection(uri, userName, password, database)
+
+    chat_mode_settings = get_chat_mode_settings(mode=mode)
+
+    messages = [HumanMessage(question)]
+
+    try:
+        llm, doc_retriever, model_version = setup_chat(model, graph, document_names, chat_mode_settings)
+        docs, transformed_question = retrieve_documents(doc_retriever, messages)
+        return {
+            'docs': docs,
+            'transformed_question': transformed_question
+        }
+    except Exception as e:
+        return {'docs': [], 'transformed_question': None}
+
 @app.post("/chat_bot")
 async def chat_bot(uri=Form(),model=Form(None),userName=Form(), password=Form(), database=Form(),question=Form(None), document_names=Form(None),session_id=Form(None),mode=Form(None),email=Form()):
     logging.info(f"QA_RAG called at {datetime.now()}")
@@ -1053,4 +1075,4 @@ async def get_schema_visualization(uri=Form(), userName=Form(), password=Form(),
         gc.collect()
 
 if __name__ == "__main__":
-    uvicorn.run(app)
\ No newline at end of file
+    uvicorn.run(app)
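
For reference, below is a minimal client-side sketch of how the new `/retrieve_docs` endpoint might be exercised. The base URL, credentials, model identifier, and document names are placeholders (assumptions, not values from this patch); only the form field names and the response shape come from the endpoint definition above.

```python
# Minimal sketch: calling the new /retrieve_docs endpoint with form data.
# All concrete values below are placeholders; only the field names mirror
# the retrieve_docs(...) signature added in this PR.
import requests

BASE_URL = "http://localhost:8000"  # assumed local uvicorn instance

payload = {
    "uri": "neo4j://localhost:7687",      # Neo4j connection URI (placeholder)
    "userName": "neo4j",                  # database user (placeholder)
    "password": "password",               # database password (placeholder)
    "database": "neo4j",                  # target database (placeholder)
    "model": "openai_gpt_4o",             # hypothetical model identifier
    "document_names": '["example.pdf"]',  # hypothetical document filter
    "mode": "vector",                     # "graph" takes the Neo4jGraph branch above
    "question": "What is this document about?",
}

# The endpoint declares its parameters with Form(), so the request body
# must be sent as form data rather than JSON.
response = requests.post(f"{BASE_URL}/retrieve_docs", data=payload)
result = response.json()

print(result["transformed_question"])
print(f"{len(result['docs'])} documents retrieved")
```

Note that on any exception the endpoint returns `{'docs': [], 'transformed_question': None}`, so a caller cannot distinguish "no matching documents" from a failed retrieval without additional server-side logging.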