diff --git a/backend/database/auth.py b/backend/database/auth.py
index 32005577c5..569a66c2f1 100644
--- a/backend/database/auth.py
+++ b/backend/database/auth.py
@@ -1,3 +1,5 @@
+import os
+
from firebase_admin import auth
from database.redis_db import cache_user_name, get_cached_user_name
@@ -10,6 +12,16 @@ def get_user_from_uid(uid: str):
print(e)
user = None
if not user:
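+        # Local development fallback: return a stub user so endpoints keep working without Firebase.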
+ if os.getenv('LOCAL_DEVELOPMENT') == 'true':
+ return {
+ 'uid': uid,
+                'email': 'local@example.com',
+ 'email_verified': True,
+ 'phone_number': '',
+ 'display_name': 'Debug',
+ 'photo_url': None,
+ 'disabled': False,
+ }
return None
return {
diff --git a/backend/models/chat.py b/backend/models/chat.py
index 32c77533c5..fe6b84ec9b 100644
--- a/backend/models/chat.py
+++ b/backend/models/chat.py
@@ -136,6 +136,32 @@ def get_sender_name(message: Message) -> str:
return '\n'.join(formatted_messages)
+ @staticmethod
+ def get_messages_as_dict(
+ messages: List['Message'], use_user_name_if_available: bool = False, use_plugin_name_if_available: bool = False
+ ) -> List[dict]:
+ sorted_messages = sorted(messages, key=lambda m: m.created_at)
+
+ def get_sender_name(message: Message) -> str:
+ if message.sender == 'human':
+ return 'user'
+            # TODO: restore the app-name lookup once app data is available here:
+            # elif use_plugin_name_if_available and message.app_id is not None:
+            #     plugin = next((p for p in plugins if p.id == message.app_id), None)
+            #     if plugin:
+            #         return plugin.name
+            return message.sender  # TODO: use app id
+
+ formatted_messages = [
+ {
+ 'role': get_sender_name(message),
+ 'content': message.text,
+ }
+ for message in sorted_messages
+ ]
+
+ return formatted_messages
+
+
class ResponseMessage(Message):
ask_for_nps: Optional[bool] = False
diff --git a/backend/models/conversation.py b/backend/models/conversation.py
index df58812a3e..7ebcfb856a 100644
--- a/backend/models/conversation.py
+++ b/backend/models/conversation.py
@@ -112,11 +112,12 @@ class ActionItem(BaseModel):
@staticmethod
def actions_to_string(action_items: List['ActionItem']) -> str:
-        if not action_items:
-            return 'None'
-        return '\n'.join(
-            [f"- {item.description} ({'completed' if item.completed else 'pending'})" for item in action_items]
-        )
+        result = []
+        for item in action_items:
+            # Tolerate raw Firestore dicts as well as ActionItem models
+            if isinstance(item, dict):
+                item = ActionItem(**item)
+            result.append(f"- {item.description} ({'completed' if item.completed else 'pending'})")
+        # Keep returning 'None' for an empty list, as callers expect
+        return '\n'.join(result) if result else 'None'
class Event(BaseModel):
@@ -278,6 +279,60 @@ def __init__(self, **data):
# Update plugins_results based on apps_results
self.plugins_results = [PluginResult(plugin_id=app.app_id, content=app.content) for app in self.apps_results]
self.processing_memory_id = self.processing_conversation_id
+
+ @staticmethod
+ def conversations_for_llm(
+ conversations: List['Conversation'],
+ use_transcript: bool = False,
+ include_timestamps: bool = False,
+ people: List[Person] = None,
+ ) -> List[dict]:
+ result = []
+ people_map = {p.id: p for p in people} if people else {}
+ for i, conversation in enumerate(conversations):
+ if isinstance(conversation, dict):
+ conversation = Conversation(**conversation)
+ item = {
+ 'index': i + 1,
+ 'category': str(conversation.structured.category.value),
+ 'title': str(conversation.structured.title),
+ 'overview': str(conversation.structured.overview),
+ 'created_at': conversation.created_at.astimezone(timezone.utc).strftime("%d %b %Y at %H:%M") + " UTC",
+ }
+
+ # attendees
+ if people_map:
+ conv_person_ids = set(conversation.get_person_ids())
+ if conv_person_ids:
+ attendees_names = [people_map[pid].name for pid in conv_person_ids if pid in people_map]
+ if attendees_names:
+ item['attendees'] = attendees_names
+
+            if conversation.structured.action_items:
+                item['actions'] = [action.description for action in conversation.structured.action_items]
+
+ if conversation.structured.events:
+ item['events'] = [{'title': event.title, 'start': event.start,
+ 'duration_minutes': event.duration} for event in conversation.structured.events]
+
+ if conversation.apps_results and len(conversation.apps_results) > 0:
+ item['summarization'] = conversation.apps_results[0].content
+
+ if use_transcript:
+ item['transcript'] = conversation.get_transcript(include_timestamps=include_timestamps, people=people)
+ # photos
+ photo_descriptions = conversation.get_photos_descriptions(include_timestamps=include_timestamps)
+ if photo_descriptions != 'None':
+ item['photos'] = photo_descriptions
+ result.append(item)
+
+        return result
+
@staticmethod
def conversations_to_string(
diff --git a/backend/requirements.txt b/backend/requirements.txt
index cc5d68e143..1ede786dbc 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -92,9 +92,9 @@ jsonschema==4.23.0
jsonschema-specifications==2023.12.1
julius==0.2.7
kiwisolver==1.4.5
-langchain==0.3.4
-langchain-community==0.3.3
-langchain-core==0.3.12
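+# langchain 1.0 pre-releases: required for the create_agent/AgentMiddleware APIs
+# used in utils/retrieval/agentic_graph.py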
+langchain==1.0.0a14
+langchain-community==0.3.31
+langchain-core==1.0.0a8
langchain-groq==0.2.0
langchain-openai==0.2.3
langchain-pinecone==0.2.0
diff --git a/backend/routers/chat.py b/backend/routers/chat.py
index 1d555467f1..4023e19850 100644
--- a/backend/routers/chat.py
+++ b/backend/routers/chat.py
@@ -1,3 +1,4 @@
+import os
import uuid
import re
import base64
@@ -35,6 +36,9 @@
from utils.other.chat_file import FileChatTool
from utils.retrieval.graph import execute_graph_chat, execute_graph_chat_stream, execute_persona_chat_stream
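+# Feature flag: when CHAT_AGENTIC is set, the agentic implementation below
+# overrides the execute_graph_chat_stream imported above.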
+if os.getenv('CHAT_AGENTIC') == 'true':
+ from utils.retrieval.agentic_graph import execute_graph_chat_stream
+
router = APIRouter()
diff --git a/backend/utils/llm/persona.py b/backend/utils/llm/persona.py
index 2109613f03..ff3d04373a 100644
--- a/backend/utils/llm/persona.py
+++ b/backend/utils/llm/persona.py
@@ -2,7 +2,13 @@
from models.app import App
from models.chat import Message, MessageSender
-from langchain.schema import SystemMessage, HumanMessage, AIMessage
+
+try:
+ from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
+except ImportError:
+ # Fallback for langchain<1.0.0
+ from langchain.schema import SystemMessage, HumanMessage, AIMessage
+
from .clients import llm_persona_mini_stream, llm_persona_medium_stream, llm_medium, llm_mini, llm_medium_experiment
diff --git a/backend/utils/retrieval/agentic_graph.py b/backend/utils/retrieval/agentic_graph.py
new file mode 100644
index 0000000000..407f8fa456
--- /dev/null
+++ b/backend/utils/retrieval/agentic_graph.py
@@ -0,0 +1,531 @@
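+"""Agentic chat pipeline built on langchain 1.0's create_agent.
+
+Enabled via the CHAT_AGENTIC env var (see routers/chat.py); exposes the same
+execute_graph_chat_stream interface as utils.retrieval.graph.
+"""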
+import os
+import uuid
+from dataclasses import dataclass
+from datetime import datetime, timezone
+from typing import List, Optional, AsyncGenerator, Annotated
+
+from langchain.agents import create_agent, AgentState
+from langchain.agents.middleware import AgentMiddleware
+from langchain.tools import InjectedState
+from langchain_core.messages import AIMessageChunk, AIMessage, ToolMessage
+from langchain_core.tools import tool, InjectedToolCallId, ToolException
+from langgraph.checkpoint.memory import MemorySaver
+from langgraph.runtime import get_runtime
+from langgraph.types import Command
+from pydantic import BaseModel
+
+import database.action_items as action_items_db
+import database.conversations as conversations_db
+import database.memories as memory_db
+import utils.apps as apps_utils
+from database.redis_db import get_filter_category_items
+from database.vector_db import query_vectors_by_metadata
+import database.notifications as notification_db
+from models.app import App
+from models.chat import ChatSession, Message
+import models.integrations as integration_models
+from models.conversation import Conversation, ActionItem
+from models.memories import Memory
+from utils.app_integrations import get_github_docs_content
+from utils.conversations.search import search_conversations
+from utils.llm.chat import (
+ retrieve_context_dates_by_question,
+ select_structured_filters
+)
+from utils.llm.clients import llm_mini_stream, llm_persona_medium_stream, llm_persona_mini_stream
+from utils.llms.memory import get_prompt_data
+from utils.other.chat_file import FileChatTool
+from utils.retrieval.state import BaseAgentState
+
+
+@dataclass
+class Context:
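+    """Per-request data exposed to tools via langgraph's runtime context."""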
+ uid: str
+ tz: str
+ app: Optional[App] = None
+ chat_session: Optional[ChatSession] = None
+ files: Optional[List[str]] = None
+
+
+def retrieve_topics_filters(uid: str, question: str = "") -> dict:
+ print("retrieve_topics_filters")
+ filters = {
+ "people": get_filter_category_items(uid, "people", limit=1000),
+ "topics": get_filter_category_items(uid, "topics", limit=1000),
+ "entities": get_filter_category_items(uid, "entities", limit=1000),
+ # 'dates': get_filter_category_items(state.get('uid'), 'dates'),
+ }
+ result = select_structured_filters(question, filters)
+ filters = {
+ "topics": result.get("topics", []),
+ "people": result.get("people", []),
+ "entities": result.get("entities", []),
+ # 'dates': result.get('dates', []),
+ }
+ print("retrieve_topics_filters filters", filters)
+ return filters
+
+def retrieve_date_filters(tz: str = "UTC", question: str = ""):
+ print('retrieve_date_filters')
+
+ # TODO: if this makes vector search fail further, query firestore instead
+ dates_range = retrieve_context_dates_by_question(question, tz)
+ print('retrieve_date_filters dates_range:', dates_range)
+ if dates_range and len(dates_range) >= 2:
+ return {"start": dates_range[0], "end": dates_range[1]}
+ return {}
+
+def query_vectors(uid: str, question: str, tz: str = "UTC", limit: int = 100):
+ print("query_vectors")
+
+ # # stream
+ # if state.get('streaming', False):
+ # state['callback'].put_thought_nowait("Searching through your memories")
+
+ date_filters = retrieve_date_filters(tz, question)
+ filters = retrieve_topics_filters(uid, question)
+ # vector = (
+ # generate_embedding(state.get("parsed_question", ""))
+ # if state.get("parsed_question")
+ # else [0] * 3072
+ # )
+
+ # Use [1] * dimension to trigger the score distance to fetch all vectors by meta filters
+ vector = [1] * 3072
+ print("query_vectors vector:", vector[:5])
+
+ # TODO: enable it when the in-accurate topic filter get fixed
+ is_topic_filter_enabled = date_filters.get("start") is None
+ conversations_id = query_vectors_by_metadata(
+ uid,
+ vector,
+ dates_filter=[date_filters.get("start"), date_filters.get("end")],
+ people=filters.get("people", []) if is_topic_filter_enabled else [],
+ topics=filters.get("topics", []) if is_topic_filter_enabled else [],
+ entities=filters.get("entities", []) if is_topic_filter_enabled else [],
+ dates=filters.get("dates", []),
+        limit=limit,
+ )
+ conversations = conversations_db.get_conversations_by_id(uid, conversations_id)
+
+ # Filter out locked conversations if user doesn't have premium access
+ conversations = [m for m in conversations if not m.get('is_locked', False)]
+
+ # stream
+ # if state.get('streaming', False):
+ # if len(memories) == 0:
+ # msg = "No relevant memories found"
+ # else:
+ # msg = f"Found {len(memories)} relevant memories"
+ # state['callback'].put_thought_nowait(msg)
+
+ # print(memories_id)
+ return conversations
+
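+# Note: parameters annotated with InjectedState/InjectedToolCallId are supplied by
+# langgraph at call time and are hidden from the model's tool schema.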
+@tool
+def get_memories(
+    # question: str = "",  # Not implemented yet
+    limit: int = 100,
+    offset: int = 0,
+    state: Annotated[AgentState, InjectedState] = None,
+    tool_call_id: Annotated[str, InjectedToolCallId] = "",
+) -> Command:
+    """ Retrieve user memories.
+
+    Args:
+        limit (int): The maximum number of memories to retrieve.
+        offset (int): The number of memories to skip.
+    """
+    print("get_memories")
+ runtime = get_runtime(Context)
+
+ memories = memory_db.get_memories(runtime.context.uid, limit=limit, offset=offset)
+ for memory in memories:
+ if memory.get('is_locked', False):
+ content = memory.get('content', '')
+ memory['content'] = (content[:70] + '...') if len(content) > 70 else content
+ memory['created_at'] = memory['created_at'].isoformat()
+ memory['updated_at'] = memory['updated_at'].isoformat()
+
+ memory_items = [integration_models.MemoryItem(**fact) for fact in memories]
+
+ # user_name, user_made_memories, generated_memories = get_prompt_data(runtime.context.uid)
+ # memories_str = (
+ # f'you already know the following facts about {user_name}: \n{Memory.get_memories_as_str(generated_memories)}.'
+ # )
+ # if user_made_memories:
+ # memories_str += (
+ # f'\n\n{user_name} also shared the following about self: \n{Memory.get_memories_as_str(user_made_memories)}'
+ # )
+ #
+ content = Memory.get_memories_as_str(memory_items)
+
+ return Command(update={
+ # 'memories': user_made_memories + generated_memories, # TODO: Refs to memories aren't supported yet
+ 'messages': [ToolMessage(content=content, tool_call_id=tool_call_id)]})
+
+@tool
+def get_conversations(
+ question: str = "",
+ page: int = 1,
+ per_page: int = 10,
+ include_discarded: bool = True,
+    start_date: Optional[str] = None,
+    end_date: Optional[str] = None,
+ state: Annotated[AgentState, InjectedState] = None,
+ tool_call_id: Annotated[str, InjectedToolCallId] = ""
+) -> Command:
+ """ Retrieve user conversations.
+
+ Args:
+        question (str): The question used to filter conversations.
+ page (int): The page number of the conversations to retrieve.
+ per_page (int): The number of conversations per page.
+ include_discarded (bool): Whether to include discarded conversations.
+ start_date (str): The start date and time in the ISO 8601 format (YYYY-MM-DDTHH:MM:SSZ) or YYYY-MM-DD.
+ end_date (str): The end date and time in the ISO 8601 format (YYYY-MM-DDTHH:MM:SSZ) or YYYY-MM-DD.
+ """
+ # Convert ISO datetime strings to Unix timestamps if provided
+ start_timestamp = None
+ end_timestamp = None
+ if isinstance(start_date, str) and start_date:
+ try:
+ start_date_str = start_date
+ if len(start_date_str) == 10: # YYYY-MM-DD
+ dt = datetime.strptime(start_date_str, '%Y-%m-%d')
+ start_dt = dt.replace(hour=0, minute=0, second=0, microsecond=0, tzinfo=timezone.utc)
+ else:
+ start_dt = datetime.fromisoformat(start_date_str.replace('Z', '+00:00'))
+ start_timestamp = int(start_dt.timestamp())
+ except ValueError:
+ raise ToolException("Error: Invalid start_date format. Use ISO format (YYYY-MM-DDTHH:MM:SSZ) or YYYY-MM-DD")
+
+ if isinstance(end_date, str) and end_date:
+ try:
+ end_date_str = end_date
+ if len(end_date_str) == 10: # YYYY-MM-DD
+ dt = datetime.strptime(end_date_str, '%Y-%m-%d')
+ end_dt = dt.replace(hour=23, minute=59, second=59, microsecond=999999, tzinfo=timezone.utc)
+ else:
+ end_dt = datetime.fromisoformat(end_date_str.replace('Z', '+00:00'))
+ end_timestamp = int(end_dt.timestamp())
+
+ except ValueError:
+ raise ToolException("Error: Invalid end_date format. Use ISO format (YYYY-MM-DDTHH:MM:SSZ) or YYYY-MM-DD")
+
+ print(f"get_conversations: {question}")
+ runtime = get_runtime(Context)
+
+ search_results = search_conversations(
+ query=question,
+ page=page,
+ per_page=per_page,
+ uid=runtime.context.uid,
+ include_discarded=include_discarded,
+ start_date=start_timestamp,
+ end_date=end_timestamp,
+ )
+
+ # Extract conversation IDs from search results
+ conversation_ids = [conv.get('id') for conv in search_results['items']]
+
+ # Get full conversation data using the IDs
+ full_conversations = []
+ if conversation_ids:
+ full_conversations = conversations_db.get_conversations_by_id(runtime.context.uid, conversation_ids)
+
+ # The old way
+ # conversations = query_vectors(runtime.context.uid, question, runtime.context.tz)
+ return Command(update={
+ 'conversations': full_conversations,
+ 'messages': [ToolMessage(content=Conversation.conversations_to_string(full_conversations), tool_call_id=tool_call_id)]})
+
+@tool
+def get_actions(question: str = "",
+ state: Annotated[AgentState, InjectedState] = None,
+ tool_call_id: Annotated[str, InjectedToolCallId] = ""
+ ) -> Command:
+ """ Retrieve user actions.
+
+ Args:
+ question (str): The question to filter user actions.
+ """
+ print(f"get_actions: {question}")
+ runtime = get_runtime(Context)
+
+ action_items = action_items_db.get_action_items(
+ uid=runtime.context.uid,
+ # conversation_id=conversation_id,
+ # completed=completed,
+ # start_date=start_date,
+ # end_date=end_date,
+ # limit=limit,
+ # offset=offset,
+ )
+
+ return Command(update={
+ 'messages': [ToolMessage(content=ActionItem.actions_to_string(action_items), tool_call_id=tool_call_id)]})
+
+@tool
+def get_omi_documentation():
+ """ Retrieve Omi device and app documentation to answer all questions like:
+ - How does it work?
+ - What can you do?
+ - How can I buy it?
+ - Where do I get it?
+ - How does the chat function?
+ """
+ context: dict = get_github_docs_content(path='docs')
+    # 'Documentation:' is a one-time prefix, not the join separator
+    return 'Documentation:\n\n' + '\n\n'.join([f'{k}:\n {v}' for k, v in context.items()])
+
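+# Structured response schema; currently unused while response_format is disabled below.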
+class ResponseFormat(BaseModel):
+ answer: str
+    memories_found: Optional[List[str]] = None
+ ask_for_nps: bool = False
+
+
+@tool
+def chat_file(question: str):
+    """ Process user inquiries about uploaded files.
+ """
+ print(f"chat_file: {question}")
+ runtime = get_runtime(Context)
+ print(runtime.context.files)
+ fc_tool = FileChatTool(runtime.context.uid, runtime.context.chat_session.id)
+ answer = fc_tool.process_chat_with_file(question, runtime.context.files)
+    # Return the plain string; @tool wraps it in a ToolMessage for the agent
+    return answer
+
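+# Prefer files attached to the last message; fall back to the chat session's files.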
+def get_files(messages: List[Message], chat_session: ChatSession = None):
+ last_message = messages[-1]
+ if len(last_message.files_id) > 0:
+ file_ids = last_message.files_id
+ elif chat_session:
+ file_ids = chat_session.file_ids
+ else:
+ file_ids = None
+
+ return file_ids
+
+class StateMiddleware(AgentMiddleware[BaseAgentState]):
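+    """Registers BaseAgentState as the agent's state schema so tool Command updates
+    can write the extra conversations/memories keys."""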
+ state_schema = BaseAgentState
+
+PROMPT_BASE = """
+You are a helpful assistant of the wearable AI device named Omi.
+
+Capabilities:
+- You can call available tools to gather context when needed.
+{CHAT_FILES}
+
+Instructions:
+- You MUST answer the question directly, concisely, and with high quality.
+- Refine the user's question using the most recent messages, but DO NOT use prior AI assistant messages as references/facts.
+- Prefer the user's memories and conversations when relevant; if they are empty or insufficient, still answer with existing general knowledge.
+- NEVER write phrases like "based on the available memories".
+- Time context: {TIME_CONTEXT}
+{PERSONA_PROMPT}
+{CITATIONS_PROMPT}
+
+{PERSONA_INSTRUCTIONS}
+{CITATIONS_INSTRUCTIONS}
+
+Report format:
+- If the user requests a report or summary of a period, structure the answer with:
+  - Goals and Achievements
+  - Mood Tracker
+  - Gratitude Log
+  - Lessons Learned
+"""
+# TODO: Add found memories to the response if get_memories was called.
+
+def create_graph(
+ uid: str,
+ messages: List[Message],
+ app: Optional[App] = None,
+ cited: Optional[bool] = False,
+ callback_data: dict = {},
+ chat_session: Optional[ChatSession] = None,
+ files: Optional[List[str]] = None,
+ tz: Optional[str] = "UTC",
+):
+ tools = [
+ # get_omi_documentation, TODO: Current doc must be formatted other way
+ ]
+
+ if app is None or apps_utils.app_can_read_memories(app.model_dump(include={'external_integration'})):
+ tools.append(get_memories)
+
+ if app is None or apps_utils.app_can_read_conversations(app.model_dump(include={'external_integration'})):
+ tools.append(get_conversations)
+
+ # if app is None:
+ # tools.append(get_actions)
+
+ if files:
+ tools.append(chat_file)
+        prompt_chat_files = "- Use the chat_file tool if the user asks about uploaded files.\n"
+ else:
+ prompt_chat_files = ""
+
+ checkpointer = MemorySaver()
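+    # In-memory checkpointer only; execute_graph_chat_stream below uses a fresh
+    # thread_id per request, so no state persists across requests.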
+
+ # Dynamic prompt sections
+    citations_block = """
+Citation rules:
+- You MUST cite the most relevant memories or conversations that support your answer.
+- Cite using [index] at the end of sentences, e.g., "You discussed optimizing firmware yesterday[1][2]".
+- NO SPACE between the last word and the citation.
+- Avoid citing irrelevant items.
+""" if cited else ""
+
+    time_context = f"Question's timezone: {tz}. Current date time in UTC: {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M, %a')}"
+
+ model = llm_mini_stream
+
+ if app and app.is_a_persona():
+ if not os.getenv('LOCAL_DEVELOPMENT'):
+ if app.is_influencer:
+ model = llm_persona_medium_stream
+ else:
+ model = llm_persona_mini_stream
+        system_prompt = PROMPT_BASE.format(
+            CHAT_FILES=prompt_chat_files,
+            CITATIONS_PROMPT="""
+- You must always cite the most relevant memories and conversations you used in your answer, following the citation rules below.
+""" if citations_block else '',
+            CITATIONS_INSTRUCTIONS=citations_block,
+            TIME_CONTEXT=time_context,
+            PERSONA_PROMPT="""
+- Adopt the persona described below.
+""",
+            PERSONA_INSTRUCTIONS=f"""
+Persona:
+{app.persona_prompt}
+""")
+    else:
+        system_prompt = PROMPT_BASE.format(
+            CHAT_FILES=prompt_chat_files,
+            CITATIONS_PROMPT="""
+- You must always cite the most relevant memories and conversations you used in your answer, following the citation rules below.
+""" if citations_block else '',
+            CITATIONS_INSTRUCTIONS=citations_block,
+            TIME_CONTEXT=time_context,
+            PERSONA_PROMPT='',
+            PERSONA_INSTRUCTIONS=''
+        )
+
+ graph = create_agent(
+ model,
+ tools,
+ system_prompt=system_prompt,
+ middleware=[StateMiddleware()],
+ checkpointer=checkpointer,
+ # response_format=ResponseFormat
+ )
+
+ return graph
+
+
+async def execute_graph_chat_stream(
+ uid: str,
+ messages: List[Message],
+ app: Optional[App] = None,
+ cited: Optional[bool] = False,
+ callback_data: dict = {},
+ chat_session: Optional[ChatSession] = None,
+) -> AsyncGenerator[str, None]:
+ print('execute_graph_chat_stream agentic app: ', app.id if app else '')
+ tz = notification_db.get_user_time_zone(uid)
+
+ files = get_files(messages, chat_session)
+
+ graph = create_graph(uid, messages, app, cited, callback_data, chat_session, files, tz)
+
+ async for event in graph.astream(
+ {
+ # uid and tz: Sent via Context
+ "cited": cited,
+ "messages": Message.get_messages_as_dict(messages),
+ # "app": app,
+ },
+ context=Context(uid=uid, tz=tz, app=app, chat_session=chat_session, files=files),
+ stream_mode=["messages", "custom", "updates"],
+ config={"configurable": {"thread_id": str(uuid.uuid4())}},
+ subgraphs=True,
+ ):
+ ns, stream_mode, payload = event
+ # print(ns, stream_mode, payload)
+ if stream_mode == "messages":
+ chunk, metadata = payload
+ metadata: dict
+ if chunk and isinstance(chunk, AIMessageChunk):
+ # Skip silent chunks (e.g., follow-up actions generation)
+ if metadata.get("silence"):
+ continue
+
+ content = str(chunk.content)
+ tool_calls = chunk.tool_calls
+
+ # Show tool execution progress
+ if tool_calls:
+ for tool_call in tool_calls:
+ tool_name_raw = tool_call.get("name")
+ print('tool_call', tool_name_raw)
+ if tool_name_raw:
+ tool_name = tool_name_raw.replace("_", " ").title()
+ yield f"think: Executing {tool_name}..."
+
+ # progress_data = format_tool_progress(tool_call)
+ # if progress_data:
+ # yield format_sse_data(progress_data)
+
+ # Only yield content from the main agent to avoid duplication
+ if content and len(ns) == 0:
+ yield f"data: {content}"
+            elif isinstance(chunk, ToolMessage):
+                # TODO: surface tool results to the client, e.g.:
+                # if chunk.name in ['get_memories', 'get_conversations']:
+                #     callback_data['memories_found'] = json.loads(chunk.content)
+                pass
+            else:
+                # Ignore other streamed chunk types
+                pass
+
+ elif stream_mode == "updates":
+ payload: dict
+ if 'tools' in payload:
+ if not callback_data.get('memories_found'):
+ callback_data['memories_found'] = []
+ if payload['tools'].get('conversations'):
+ callback_data['memories_found'] += payload['tools']['conversations']
+ if payload['tools'].get('memories'):
+ callback_data['memories_found'] += payload['tools']['memories']
+
+ for k in ['model', 'agent']:
+ if k in payload:
+ last_message: AIMessage = payload[k]['messages'][0]
+                    if last_message.response_metadata.get('finish_reason') == 'stop':
+ callback_data['answer'] = last_message.content
+ # callback_data['answer'] = payload[k]['structured_response'].answer
+
+ elif stream_mode == "custom":
+ # Forward custom events as is
+ yield f"data: {payload}"
+
+ yield None
+ return
diff --git a/backend/utils/retrieval/graph.py b/backend/utils/retrieval/graph.py
index 08d01d5881..0c5aee6c7e 100644
--- a/backend/utils/retrieval/graph.py
+++ b/backend/utils/retrieval/graph.py
@@ -3,7 +3,13 @@
import asyncio
from typing import List, Optional, Tuple, AsyncGenerator
-from langchain.callbacks.base import BaseCallbackHandler
+# TODO Remove after complete upgrade to langchain>=1.0.0
+try:
+ from langchain_core.callbacks import BaseCallbackHandler
+except ImportError:
+ # Fallback for langchain<1.0.0
+ from langchain.callbacks.base import BaseCallbackHandler
+
from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
diff --git a/backend/utils/retrieval/state.py b/backend/utils/retrieval/state.py
new file mode 100644
index 0000000000..8dea584b3f
--- /dev/null
+++ b/backend/utils/retrieval/state.py
@@ -0,0 +1,13 @@
+# TODO Remove after the minimum Python version is 3.11
+try:
+    from typing import NotRequired, TypedDict
+except ImportError:
+    # Fallback for Python < 3.11
+    from typing_extensions import NotRequired, TypedDict
+
+from langchain.agents import AgentState
+
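+# Lightweight references stored in agent state by tools (see get_conversations in agentic_graph).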
+class Conversation(TypedDict):
+ text: str
+
+class Memory(TypedDict):
+ text: str
+
+class BaseAgentState(AgentState):
+ conversations: NotRequired[list[Conversation]]
+ memories: NotRequired[list[Memory]]