diff --git a/.gitignore b/.gitignore
index b4960ee62f..63cecd06cd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -23,6 +23,7 @@
 # Ignore .envrc files (used by direnv: https://direnv.net/)
 .envrc
+certs/
 
 # ignore storage dir
 /storage
@@ -62,4 +63,4 @@ token.json
 
 # Coverage
 coverage_report
-.coverage
\ No newline at end of file
+.coverage
diff --git a/cookbook/README.md b/cookbook/README.md
index cc24ca6286..cc0477f1e8 100644
--- a/cookbook/README.md
+++ b/cookbook/README.md
@@ -12,7 +12,7 @@ The concepts cookbook walks through the core concepts of Agno.
 - [RAG](./agent_concepts/rag)
 - [Knowledge](./agent_concepts/knowledge)
 - [Memory](./agent_concepts/memory)
-- [Storage](./agent_concepts/storage)
+- [Storage](storage)
 - [Tools](./agent_concepts/tools)
 - [Reasoning](./agent_concepts/reasoning)
 - [Vector DBs](./agent_concepts/vector_dbs)
diff --git a/cookbook/examples/agents/agno_support_agent.py b/cookbook/examples/agents/agno_support_agent.py
index c4e7554194..15381554b0 100644
--- a/cookbook/examples/agents/agno_support_agent.py
+++ b/cookbook/examples/agents/agno_support_agent.py
@@ -31,7 +31,6 @@
 from agno.knowledge.url import UrlKnowledge
 from agno.models.openai import OpenAIChat
 from agno.storage.agent.sqlite import SqliteAgentStorage
-from agno.tools.duckduckgo import DuckDuckGoTools
 from agno.tools.python import PythonTools
 from agno.vectordb.lancedb import LanceDb, SearchType
 from rich import print
diff --git a/cookbook/getting_started/04_agent_with_storage.py b/cookbook/getting_started/04_agent_with_storage.py
index 8dbbbf32e9..f7f2ebea70 100644
--- a/cookbook/getting_started/04_agent_with_storage.py
+++ b/cookbook/getting_started/04_agent_with_storage.py
@@ -23,7 +23,7 @@
 from agno.embedder.openai import OpenAIEmbedder
 from agno.knowledge.pdf_url import PDFUrlKnowledgeBase
 from agno.models.openai import OpenAIChat
-from agno.storage.agent.sqlite import SqliteAgentStorage
+from agno.storage.sqlite import SqliteStorage
 from agno.tools.duckduckgo import DuckDuckGoTools
 from agno.vectordb.lancedb import LanceDb, SearchType
 from rich import print
@@ -41,7 +41,7 @@
 # if agent_knowledge is not None:
 #     agent_knowledge.load()
 
-agent_storage = SqliteAgentStorage(table_name="recipe_agent", db_file="tmp/agents.db")
+agent_storage = SqliteStorage(table_name="recipe_agent", db_file="tmp/agents.db")
 
 
 def recipe_agent(user: str = "user"):
diff --git a/cookbook/getting_started/09_research_workflow.py b/cookbook/getting_started/09_research_workflow.py
index ea346c003b..f9626643d8 100644
--- a/cookbook/getting_started/09_research_workflow.py
+++ b/cookbook/getting_started/09_research_workflow.py
@@ -26,7 +26,7 @@
 from agno.agent import Agent
 from agno.models.openai import OpenAIChat
-from agno.storage.workflow.sqlite import SqliteWorkflowStorage
+from agno.storage.sqlite import SqliteStorage
 from agno.tools.duckduckgo import DuckDuckGoTools
 from agno.tools.newspaper4k import Newspaper4kTools
 from agno.utils.log import logger
@@ -408,7 +408,7 @@ def write_research_report(
     # Initialize the news report generator workflow
     generate_research_report = ResearchReportGenerator(
         session_id=f"generate-report-on-{url_safe_topic}",
-        storage=SqliteWorkflowStorage(
+        storage=SqliteStorage(
             table_name="generate_research_report_workflow",
             db_file="tmp/workflows.db",
         ),
diff --git a/cookbook/getting_started/16_agent_session.py b/cookbook/getting_started/16_agent_session.py
index 9b89afb1a9..2dec02f15c 100644
--- a/cookbook/getting_started/16_agent_session.py
+++ b/cookbook/getting_started/16_agent_session.py
@@ -17,7 +17,7 @@
 import typer
 from agno.agent import Agent
 from agno.models.openai import OpenAIChat
-from agno.storage.agent.sqlite import SqliteAgentStorage
+from agno.storage.sqlite import SqliteStorage
 from rich import print
 from rich.console import Console
 from rich.json import JSON
@@ -34,9 +34,7 @@ def create_agent(user: str = "user"):
     new = typer.confirm("Do you want to start a new session?")
 
     # Get existing session if user doesn't want a new one
-    agent_storage = SqliteAgentStorage(
-        table_name="agent_sessions", db_file="tmp/agents.db"
-    )
+    agent_storage = SqliteStorage(table_name="agent_sessions", db_file="tmp/agents.db")
 
     if not new:
         existing_sessions = agent_storage.get_all_session_ids(user)
diff --git a/cookbook/getting_started/17_user_memories_and_summaries.py b/cookbook/getting_started/17_user_memories_and_summaries.py
index 48cba16adc..0c88480036 100644
--- a/cookbook/getting_started/17_user_memories_and_summaries.py
+++ b/cookbook/getting_started/17_user_memories_and_summaries.py
@@ -29,7 +29,7 @@
 from agno.agent import Agent, AgentMemory
 from agno.memory.db.sqlite import SqliteMemoryDb
 from agno.models.openai import OpenAIChat
-from agno.storage.agent.sqlite import SqliteAgentStorage
+from agno.storage.sqlite import SqliteStorage
 from rich.console import Console
 from rich.json import JSON
 from rich.panel import Panel
@@ -43,9 +43,7 @@ def create_agent(user: str = "user"):
     new = typer.confirm("Do you want to start a new session?")
 
     # Initialize storage for both agent sessions and memories
-    agent_storage = SqliteAgentStorage(
-        table_name="agent_memories", db_file="tmp/agents.db"
-    )
+    agent_storage = SqliteStorage(table_name="agent_memories", db_file="tmp/agents.db")
 
     if not new:
         existing_sessions = agent_storage.get_all_session_ids(user)
diff --git a/cookbook/hackathon/playground/blog_to_podcast.py b/cookbook/hackathon/playground/blog_to_podcast.py
index 4b41d5e6a0..e5b160c293 100644
--- a/cookbook/hackathon/playground/blog_to_podcast.py
+++ b/cookbook/hackathon/playground/blog_to_podcast.py
@@ -9,7 +9,7 @@
 from agno.agent import Agent
 from agno.models.openai import OpenAIChat
 from agno.playground import Playground, serve_playground_app
-from agno.storage.agent.sqlite import SqliteAgentStorage
+from agno.storage.sqlite import SqliteStorage
 from agno.tools.eleven_labs import ElevenLabsTools
 from agno.tools.firecrawl import FirecrawlTools
 
@@ -43,7 +43,7 @@
     markdown=True,
     debug_mode=True,
    add_history_to_messages=True,
-    storage=SqliteAgentStorage(
+    storage=SqliteStorage(
         table_name="blog_to_podcast_agent", db_file=image_agent_storage_file
     ),
 )
diff --git a/cookbook/hackathon/playground/demo.py b/cookbook/hackathon/playground/demo.py
index bc25a3fb17..6fb99a707a 100644
--- a/cookbook/hackathon/playground/demo.py
+++ b/cookbook/hackathon/playground/demo.py
@@ -1,15 +1,11 @@
 """Run `pip install openai exa_py duckduckgo-search yfinance pypdf sqlalchemy 'fastapi[standard]' youtube-transcript-api python-docx agno` to install dependencies."""
 
-from datetime import datetime
-from textwrap import dedent
-
 from agno.agent import Agent
 from agno.models.openai import OpenAIChat
 from agno.playground import Playground, serve_playground_app
-from agno.storage.agent.sqlite import SqliteAgentStorage
+from agno.storage.sqlite import SqliteStorage
 from agno.tools.dalle import DalleTools
 from agno.tools.duckduckgo import DuckDuckGoTools
-from agno.tools.exa import ExaTools
 from agno.tools.yfinance import YFinanceTools
 from agno.tools.youtube import YouTubeTools
 
@@ -22,7 +18,7 @@
     role="Answer basic questions",
     agent_id="simple-agent",
     model=OpenAIChat(id="gpt-4o-mini"),
-    storage=SqliteAgentStorage(table_name="simple_agent", db_file=agent_storage_file),
+    storage=SqliteStorage(table_name="simple_agent", db_file=agent_storage_file),
     add_history_to_messages=True,
     num_history_responses=3,
     add_datetime_to_instructions=True,
@@ -39,7 +35,7 @@
         "Break down the users request into 2-3 different searches.",
         "Always include sources",
     ],
-    storage=SqliteAgentStorage(table_name="web_agent", db_file=agent_storage_file),
+    storage=SqliteStorage(table_name="web_agent", db_file=agent_storage_file),
     add_history_to_messages=True,
     num_history_responses=5,
     add_datetime_to_instructions=True,
@@ -60,7 +56,7 @@
         )
     ],
     instructions=["Always use tables to display data"],
-    storage=SqliteAgentStorage(table_name="finance_agent", db_file=agent_storage_file),
+    storage=SqliteStorage(table_name="finance_agent", db_file=agent_storage_file),
     add_history_to_messages=True,
     num_history_responses=5,
     add_datetime_to_instructions=True,
@@ -81,9 +77,7 @@
     debug_mode=True,
     add_history_to_messages=True,
     add_datetime_to_instructions=True,
-    storage=SqliteAgentStorage(
-        table_name="image_agent", db_file=image_agent_storage_file
-    ),
+    storage=SqliteStorage(table_name="image_agent", db_file=image_agent_storage_file),
 )
 
 youtube_agent = Agent(
@@ -102,7 +96,7 @@
     num_history_responses=5,
     show_tool_calls=True,
     add_datetime_to_instructions=True,
-    storage=SqliteAgentStorage(table_name="youtube_agent", db_file=agent_storage_file),
+    storage=SqliteStorage(table_name="youtube_agent", db_file=agent_storage_file),
     markdown=True,
 )
diff --git a/cookbook/hackathon/playground/multimodal_agents.py b/cookbook/hackathon/playground/multimodal_agents.py
index 2debd7eeab..f20357e35d 100644
--- a/cookbook/hackathon/playground/multimodal_agents.py
+++ b/cookbook/hackathon/playground/multimodal_agents.py
@@ -10,7 +10,7 @@
 from agno.models.openai import OpenAIChat
 from agno.models.response import FileType
 from agno.playground import Playground, serve_playground_app
-from agno.storage.agent.sqlite import SqliteAgentStorage
+from agno.storage.sqlite import SqliteStorage
 from agno.tools.dalle import DalleTools
 from agno.tools.eleven_labs import ElevenLabsTools
 from agno.tools.fal import FalTools
@@ -33,9 +33,7 @@
     debug_mode=True,
     add_history_to_messages=True,
     add_datetime_to_instructions=True,
-    storage=SqliteAgentStorage(
-        table_name="image_agent", db_file=image_agent_storage_file
-    ),
+    storage=SqliteStorage(table_name="image_agent", db_file=image_agent_storage_file),
 )
 
 ml_gif_agent = Agent(
@@ -52,9 +50,7 @@
     debug_mode=True,
     add_history_to_messages=True,
     add_datetime_to_instructions=True,
-    storage=SqliteAgentStorage(
-        table_name="ml_gif_agent", db_file=image_agent_storage_file
-    ),
+    storage=SqliteStorage(table_name="ml_gif_agent", db_file=image_agent_storage_file),
 )
 
 ml_music_agent = Agent(
@@ -78,7 +74,7 @@
     debug_mode=True,
     add_history_to_messages=True,
     add_datetime_to_instructions=True,
-    storage=SqliteAgentStorage(
+    storage=SqliteStorage(
         table_name="ml_music_agent", db_file=image_agent_storage_file
     ),
 )
@@ -97,7 +93,7 @@
     debug_mode=True,
     add_history_to_messages=True,
     add_datetime_to_instructions=True,
-    storage=SqliteAgentStorage(
+    storage=SqliteStorage(
         table_name="ml_video_agent", db_file=image_agent_storage_file
     ),
 )
@@ -116,9 +112,7 @@
     debug_mode=True,
     add_history_to_messages=True,
     add_datetime_to_instructions=True,
-    storage=SqliteAgentStorage(
-        table_name="fal_agent", db_file=image_agent_storage_file
-    ),
+    storage=SqliteStorage(table_name="fal_agent", db_file=image_agent_storage_file),
 )
 
 gif_agent = Agent(
@@ -135,9 +129,7 @@
     debug_mode=True,
     add_history_to_messages=True,
     add_datetime_to_instructions=True,
-    storage=SqliteAgentStorage(
-        table_name="gif_agent", db_file=image_agent_storage_file
-    ),
+    storage=SqliteStorage(table_name="gif_agent", db_file=image_agent_storage_file),
 )
 
 audio_agent = Agent(
@@ -163,9 +155,7 @@
     debug_mode=True,
     add_history_to_messages=True,
     add_datetime_to_instructions=True,
-    storage=SqliteAgentStorage(
-        table_name="audio_agent", db_file=image_agent_storage_file
-    ),
+    storage=SqliteStorage(table_name="audio_agent", db_file=image_agent_storage_file),
 )
diff --git a/cookbook/models/anthropic/memory.py b/cookbook/models/anthropic/memory.py
index 9821634fdc..7004bbe3bf 100644
--- a/cookbook/models/anthropic/memory.py
+++ b/cookbook/models/anthropic/memory.py
@@ -9,8 +9,7 @@
 from agno.agent import Agent, AgentMemory
 from agno.memory.db.postgres import PgMemoryDb
 from agno.models.anthropic import Claude
-from agno.storage.agent.postgres import PostgresAgentStorage
-from rich.pretty import pprint
+from agno.storage.postgres import PostgresStorage
 
 db_url = "postgresql+psycopg://ai:ai@localhost:5532/ai"
 agent = Agent(
@@ -22,9 +21,7 @@
         create_session_summary=True,
     ),
     # Store agent sessions in a database
-    storage=PostgresAgentStorage(
-        table_name="personalized_agent_sessions", db_url=db_url
-    ),
+    storage=PostgresStorage(table_name="personalized_agent_sessions", db_url=db_url),
     # Show debug logs so, you can see the memory being created
     # debug_mode=True,
 )
diff --git a/cookbook/models/anthropic/storage.py b/cookbook/models/anthropic/storage.py
index 6c6b1bf440..395110c6d7 100644
--- a/cookbook/models/anthropic/storage.py
+++ b/cookbook/models/anthropic/storage.py
@@ -2,14 +2,14 @@
 
 from agno.agent import Agent
 from agno.models.anthropic import Claude
-from agno.storage.agent.postgres import PostgresAgentStorage
+from agno.storage.postgres import PostgresStorage
 from agno.tools.duckduckgo import DuckDuckGoTools
 
 db_url = "postgresql+psycopg://ai:ai@localhost:5532/ai"
 
 agent = Agent(
     model=Claude(id="claude-3-5-sonnet-20241022"),
-    storage=PostgresAgentStorage(table_name="agent_sessions", db_url=db_url),
+    storage=PostgresStorage(table_name="agent_sessions", db_url=db_url),
     tools=[DuckDuckGoTools()],
     add_history_to_messages=True,
 )
diff --git a/cookbook/models/aws/claude/storage.py b/cookbook/models/aws/claude/storage.py
index 6288ba4ec1..1e9e4c9086 100644
--- a/cookbook/models/aws/claude/storage.py
+++ b/cookbook/models/aws/claude/storage.py
@@ -2,14 +2,14 @@
 
 from agno.agent import Agent
 from agno.models.aws import Claude
-from agno.storage.agent.postgres import PostgresAgentStorage
+from agno.storage.postgres import PostgresStorage
 from agno.tools.duckduckgo import DuckDuckGoTools
 
 db_url = "postgresql+psycopg://ai:ai@localhost:5532/ai"
 
 agent = Agent(
     model=Claude(id="anthropic.claude-3-5-sonnet-20240620-v1:0"),
-    storage=PostgresAgentStorage(table_name="agent_sessions", db_url=db_url),
+    storage=PostgresStorage(table_name="agent_sessions", db_url=db_url),
     tools=[DuckDuckGoTools()],
     add_history_to_messages=True,
 )
diff --git a/cookbook/models/azure/ai_foundry/storage.py b/cookbook/models/azure/ai_foundry/storage.py
index 31ca047610..6664571b07 100644
--- a/cookbook/models/azure/ai_foundry/storage.py
+++ b/cookbook/models/azure/ai_foundry/storage.py
@@ -2,14 +2,14 @@
 
 from agno.agent import Agent
 from agno.models.azure import AzureAIFoundry
-from agno.storage.agent.postgres import PostgresAgentStorage
+from agno.storage.postgres import PostgresStorage
 from agno.tools.duckduckgo import DuckDuckGoTools
 
 db_url = "postgresql+psycopg://ai:ai@localhost:5532/ai"
 
 agent = Agent(
     model=AzureAIFoundry(id="Phi-4"),
-    storage=PostgresAgentStorage(table_name="agent_sessions", db_url=db_url),
+    storage=PostgresStorage(table_name="agent_sessions", db_url=db_url),
     tools=[DuckDuckGoTools()],
     add_history_to_messages=True,
 )
diff --git a/cookbook/models/azure/openai/storage.py b/cookbook/models/azure/openai/storage.py
index 1e22cae34c..48e6f2e88e 100644
--- a/cookbook/models/azure/openai/storage.py
+++ b/cookbook/models/azure/openai/storage.py
@@ -2,14 +2,14 @@
 
 from agno.agent import Agent
 from agno.models.azure import AzureOpenAI
-from agno.storage.agent.postgres import PostgresAgentStorage
+from agno.storage.postgres import PostgresStorage
 from agno.tools.duckduckgo import DuckDuckGoTools
 
 db_url = "postgresql+psycopg://ai:ai@localhost:5532/ai"
 
 agent = Agent(
     model=AzureOpenAI(id="gpt-4o-mini"),
-    storage=PostgresAgentStorage(table_name="agent_sessions", db_url=db_url),
+    storage=PostgresStorage(table_name="agent_sessions", db_url=db_url),
     tools=[DuckDuckGoTools()],
     add_history_to_messages=True,
 )
diff --git a/cookbook/models/cohere/memory.py b/cookbook/models/cohere/memory.py
index c2e2b80d4b..a7d20ca14b 100644
--- a/cookbook/models/cohere/memory.py
+++ b/cookbook/models/cohere/memory.py
@@ -9,7 +9,7 @@
 from agno.agent import Agent, AgentMemory
 from agno.memory.db.postgres import PgMemoryDb
 from agno.models.cohere import Cohere
-from agno.storage.agent.postgres import PostgresAgentStorage
+from agno.storage.postgres import PostgresStorage
 
 db_url = "postgresql+psycopg://ai:ai@localhost:5532/ai"
 agent = Agent(
@@ -21,9 +21,7 @@
         create_session_summary=True,
     ),
     # Store agent sessions in a database
-    storage=PostgresAgentStorage(
-        table_name="personalized_agent_sessions", db_url=db_url
-    ),
+    storage=PostgresStorage(table_name="personalized_agent_sessions", db_url=db_url),
     # Show debug logs so, you can see the memory being created
     # debug_mode=True,
 )
diff --git a/cookbook/models/cohere/storage.py b/cookbook/models/cohere/storage.py
index 0c266bfb5a..4d88496f48 100644
--- a/cookbook/models/cohere/storage.py
+++ b/cookbook/models/cohere/storage.py
@@ -2,14 +2,14 @@
 
 from agno.agent import Agent
 from agno.models.cohere import Cohere
-from agno.storage.agent.postgres import PostgresAgentStorage
+from agno.storage.postgres import PostgresStorage
 from agno.tools.duckduckgo import DuckDuckGoTools
 
 db_url = "postgresql+psycopg://ai:ai@localhost:5532/ai"
 
 agent = Agent(
     model=Cohere(id="command-r-08-2024"),
-    storage=PostgresAgentStorage(table_name="agent_sessions", db_url=db_url),
+    storage=PostgresStorage(table_name="agent_sessions", db_url=db_url),
     tools=[DuckDuckGoTools()],
     add_history_to_messages=True,
 )
diff --git a/cookbook/models/google/gemini/storage.py b/cookbook/models/google/gemini/storage.py
index 879919b6ec..c4d2331e3b 100644
--- a/cookbook/models/google/gemini/storage.py
+++ b/cookbook/models/google/gemini/storage.py
@@ -2,14 +2,14 @@
 
 from agno.agent import Agent
 from agno.models.google import Gemini
-from agno.storage.agent.postgres import PostgresAgentStorage
+from agno.storage.postgres import PostgresStorage
 from agno.tools.duckduckgo import DuckDuckGoTools
 
 db_url = "postgresql+psycopg://ai:ai@localhost:5532/ai"
 
 agent = Agent(
     model=Gemini(id="gemini-2.0-flash-exp"),
-    storage=PostgresAgentStorage(table_name="agent_sessions", db_url=db_url),
+    storage=PostgresStorage(table_name="agent_sessions", db_url=db_url),
     tools=[DuckDuckGoTools()],
     add_history_to_messages=True,
 )
diff --git a/cookbook/models/google/gemini/storage_and_memory.py b/cookbook/models/google/gemini/storage_and_memory.py
index 77deba6566..6e6d2b9d0c 100644
--- a/cookbook/models/google/gemini/storage_and_memory.py
+++ b/cookbook/models/google/gemini/storage_and_memory.py
@@ -5,7 +5,7 @@
 from agno.memory import AgentMemory
 from agno.memory.db.postgres import PgMemoryDb
 from agno.models.google import Gemini
-from agno.storage.agent.postgres import PostgresAgentStorage
+from agno.storage.postgres import PostgresStorage
 from agno.tools.duckduckgo import DuckDuckGoTools
 from agno.vectordb.pgvector import PgVector
 
@@ -21,7 +21,7 @@
     model=Gemini(id="gemini-2.0-flash-exp"),
     tools=[DuckDuckGoTools()],
     knowledge=knowledge_base,
-    storage=PostgresAgentStorage(table_name="agent_sessions", db_url=db_url),
+    storage=PostgresStorage(table_name="agent_sessions", db_url=db_url),
     # Store the memories and summary in a database
     memory=AgentMemory(
         db=PgMemoryDb(table_name="agent_memory", db_url=db_url),
diff --git a/cookbook/models/groq/deep_knowledge.py b/cookbook/models/groq/deep_knowledge.py
index df7e2ad275..80abc0a568 100644
--- a/cookbook/models/groq/deep_knowledge.py
+++ b/cookbook/models/groq/deep_knowledge.py
@@ -22,7 +22,7 @@
 from agno.embedder.openai import OpenAIEmbedder
 from agno.knowledge.url import UrlKnowledge
 from agno.models.groq import Groq
-from agno.storage.agent.sqlite import SqliteAgentStorage
+from agno.storage.sqlite import SqliteStorage
 from agno.vectordb.lancedb import LanceDb, SearchType
 from rich import print
 
@@ -47,9 +47,7 @@ def initialize_knowledge_base():
 
 def get_agent_storage():
     """Return agent storage"""
-    return SqliteAgentStorage(
-        table_name="deep_knowledge_sessions", db_file="tmp/agents.db"
-    )
+    return SqliteStorage(table_name="deep_knowledge_sessions", db_file="tmp/agents.db")
 
 
 def create_agent(session_id: Optional[str] = None) -> Agent:
diff --git a/cookbook/models/groq/storage.py b/cookbook/models/groq/storage.py
index 66db3b0cc9..05f83f5045 100644
--- a/cookbook/models/groq/storage.py
+++ b/cookbook/models/groq/storage.py
@@ -2,14 +2,14 @@
 
 from agno.agent import Agent
 from agno.models.groq import Groq
-from agno.storage.agent.postgres import PostgresAgentStorage
+from agno.storage.postgres import PostgresStorage
 from agno.tools.duckduckgo import DuckDuckGoTools
 
 db_url = "postgresql+psycopg://ai:ai@localhost:5532/ai"
 
 agent = Agent(
     model=Groq(id="llama-3.3-70b-versatile"),
-    storage=PostgresAgentStorage(table_name="agent_sessions", db_url=db_url),
+    storage=PostgresStorage(table_name="agent_sessions", db_url=db_url),
     tools=[DuckDuckGoTools()],
     add_history_to_messages=True,
 )
diff --git a/cookbook/models/ibm/watsonx/storage.py b/cookbook/models/ibm/watsonx/storage.py
index 8a2c5f2445..19dff70877 100644
--- a/cookbook/models/ibm/watsonx/storage.py
+++ b/cookbook/models/ibm/watsonx/storage.py
@@ -2,14 +2,14 @@
 
 from agno.agent import Agent
 from agno.models.ibm import WatsonX
-from agno.storage.agent.postgres import PostgresAgentStorage
+from agno.storage.postgres import PostgresStorage
 from agno.tools.duckduckgo import DuckDuckGoTools
 
 db_url = "postgresql+psycopg://ai:ai@localhost:5532/ai"
 
 agent = Agent(
     model=WatsonX(id="ibm/granite-20b-code-instruct"),
-    storage=PostgresAgentStorage(table_name="agent_sessions", db_url=db_url),
+    storage=PostgresStorage(table_name="agent_sessions", db_url=db_url),
     tools=[DuckDuckGoTools()],
     add_history_to_messages=True,
 )
diff --git a/cookbook/models/mistral/memory.py b/cookbook/models/mistral/memory.py
index 168bc462d7..6f2ce735e2 100644
--- a/cookbook/models/mistral/memory.py
+++ b/cookbook/models/mistral/memory.py
@@ -9,7 +9,7 @@
 from agno.agent import Agent, AgentMemory
 from agno.memory.db.postgres import PgMemoryDb
 from agno.models.mistral.mistral import MistralChat
-from agno.storage.agent.postgres import PostgresAgentStorage
+from agno.storage.postgres import PostgresStorage
 from agno.tools.duckduckgo import DuckDuckGoTools
 
 db_url = "postgresql+psycopg://ai:ai@localhost:5532/ai"
@@ -23,9 +23,7 @@
         create_session_summary=True,
     ),
     # Store agent sessions in a database
-    storage=PostgresAgentStorage(
-        table_name="personalized_agent_sessions", db_url=db_url
-    ),
+    storage=PostgresStorage(table_name="personalized_agent_sessions", db_url=db_url),
     show_tool_calls=True,
     # Show debug logs so, you can see the memory being created
     # debug_mode=True,
diff --git a/cookbook/models/ollama/memory.py b/cookbook/models/ollama/memory.py
index ce31066c82..fde0ecc6c0 100644
--- a/cookbook/models/ollama/memory.py
+++ b/cookbook/models/ollama/memory.py
@@ -9,7 +9,7 @@
 from agno.agent import Agent, AgentMemory
 from agno.memory.db.postgres import PgMemoryDb
 from agno.models.ollama.chat import Ollama
-from agno.storage.agent.postgres import PostgresAgentStorage
+from agno.storage.postgres import PostgresStorage
 
 db_url = "postgresql+psycopg://ai:ai@localhost:5532/ai"
 agent = Agent(
@@ -21,9 +21,7 @@
         create_session_summary=True,
     ),
     # Store agent sessions in a database
-    storage=PostgresAgentStorage(
-        table_name="personalized_agent_sessions", db_url=db_url
-    ),
+    storage=PostgresStorage(table_name="personalized_agent_sessions", db_url=db_url),
     # Show debug logs so, you can see the memory being created
     # debug_mode=True,
 )
diff --git a/cookbook/models/ollama/storage.py b/cookbook/models/ollama/storage.py
index b5ad2c630e..5a42313ba7 100644
--- a/cookbook/models/ollama/storage.py
+++ b/cookbook/models/ollama/storage.py
@@ -2,14 +2,14 @@
 
 from agno.agent import Agent
 from agno.models.ollama import Ollama
-from agno.storage.agent.postgres import PostgresAgentStorage
+from agno.storage.postgres import PostgresStorage
 from agno.tools.duckduckgo import DuckDuckGoTools
 
 db_url = "postgresql+psycopg://ai:ai@localhost:5532/ai"
 
 agent = Agent(
     model=Ollama(id="llama3.1:8b"),
-    storage=PostgresAgentStorage(table_name="agent_sessions", db_url=db_url),
+    storage=PostgresStorage(table_name="agent_sessions", db_url=db_url),
     tools=[DuckDuckGoTools()],
     add_history_to_messages=True,
 )
diff --git a/cookbook/models/ollama_tools/storage.py b/cookbook/models/ollama_tools/storage.py
index a3995a5a54..bc65f4c37c 100644
--- a/cookbook/models/ollama_tools/storage.py
+++ b/cookbook/models/ollama_tools/storage.py
@@ -2,14 +2,14 @@
 
 from agno.agent import Agent
 from agno.models.ollama import OllamaTools
-from agno.storage.agent.postgres import PostgresAgentStorage
+from agno.storage.postgres import PostgresStorage
 from agno.tools.duckduckgo import DuckDuckGoTools
 
 db_url = "postgresql+psycopg://ai:ai@localhost:5532/ai"
 
 agent = Agent(
     model=OllamaTools(id="llama3.1:8b"),
-    storage=PostgresAgentStorage(table_name="agent_sessions", db_url=db_url),
+    storage=PostgresStorage(table_name="agent_sessions", db_url=db_url),
     tools=[DuckDuckGoTools()],
     add_history_to_messages=True,
 )
diff --git a/cookbook/models/openai/chat/memory.py b/cookbook/models/openai/chat/memory.py
index 890f77278c..e708ad0e32 100644
--- a/cookbook/models/openai/chat/memory.py
+++ b/cookbook/models/openai/chat/memory.py
@@ -9,7 +9,7 @@
 from agno.agent import Agent, AgentMemory
 from agno.memory.db.postgres import PgMemoryDb
 from agno.models.openai import OpenAIChat
-from agno.storage.agent.postgres import PostgresAgentStorage
+from agno.storage.postgres import PostgresStorage
 from rich.pretty import pprint
 
 db_url = "postgresql+psycopg://ai:ai@localhost:5532/ai"
@@ -22,9 +22,7 @@
         create_session_summary=True,
     ),
     # Store agent sessions in a database
-    storage=PostgresAgentStorage(
-        table_name="personalized_agent_sessions", db_url=db_url
-    ),
+    storage=PostgresStorage(table_name="personalized_agent_sessions", db_url=db_url),
     # Show debug logs so, you can see the memory being created
     # debug_mode=True,
 )
diff --git a/cookbook/models/openai/chat/storage.py b/cookbook/models/openai/chat/storage.py
index f40f85e1be..36982e7ab3 100644
--- a/cookbook/models/openai/chat/storage.py
+++ b/cookbook/models/openai/chat/storage.py
@@ -2,14 +2,14 @@
 
 from agno.agent import Agent
 from agno.models.openai import OpenAIChat
-from agno.storage.agent.postgres import PostgresAgentStorage
+from agno.storage.postgres import PostgresStorage
 from agno.tools.duckduckgo import DuckDuckGoTools
 
 db_url = "postgresql+psycopg://ai:ai@localhost:5532/ai"
 
 agent = Agent(
     model=OpenAIChat(id="gpt-4o"),
-    storage=PostgresAgentStorage(table_name="agent_sessions", db_url=db_url),
+    storage=PostgresStorage(table_name="agent_sessions", db_url=db_url),
     tools=[DuckDuckGoTools()],
     add_history_to_messages=True,
 )
diff --git a/cookbook/models/perplexity/memory.py b/cookbook/models/perplexity/memory.py
index 2165319379..d3df1356e9 100644
--- a/cookbook/models/perplexity/memory.py
+++ b/cookbook/models/perplexity/memory.py
@@ -9,7 +9,7 @@
 from agno.agent import Agent, AgentMemory
 from agno.memory.db.postgres import PgMemoryDb
 from agno.models.perplexity import Perplexity
-from agno.storage.agent.postgres import PostgresAgentStorage
+from agno.storage.postgres import PostgresStorage
 from rich.pretty import pprint
 
 db_url = "postgresql+psycopg://ai:ai@localhost:5532/ai"
@@ -22,9 +22,7 @@
         create_session_summary=True,
     ),
     # Store agent sessions in a database
-    storage=PostgresAgentStorage(
-        table_name="personalized_agent_sessions", db_url=db_url
-    ),
+    storage=PostgresStorage(table_name="personalized_agent_sessions", db_url=db_url),
     # Show debug logs so, you can see the memory being created
     # debug_mode=True,
 )
diff --git a/cookbook/playground/agno_assist.py b/cookbook/playground/agno_assist.py
index 9d50692489..44de8bab5f 100644
--- a/cookbook/playground/agno_assist.py
+++ b/cookbook/playground/agno_assist.py
@@ -33,7 +33,7 @@
 from agno.knowledge.url import UrlKnowledge
 from agno.models.openai import OpenAIChat
 from agno.playground import Playground, serve_playground_app
-from agno.storage.agent.sqlite import SqliteAgentStorage
+from agno.storage.sqlite import SqliteStorage
 from agno.tools.dalle import DalleTools
 from agno.tools.eleven_labs import ElevenLabsTools
 from agno.tools.python import PythonTools
@@ -148,9 +148,7 @@
         ),
         DalleTools(model="dall-e-3", size="1792x1024", quality="hd", style="vivid"),
     ],
-    storage=SqliteAgentStorage(
-        table_name="agno_assist_sessions", db_file="tmp/agents.db"
-    ),
+    storage=SqliteStorage(table_name="agno_assist_sessions", db_file="tmp/agents.db"),
     add_history_to_messages=True,
     add_datetime_to_instructions=True,
     markdown=True,
@@ -168,9 +166,7 @@
     instructions=_instructions,
     knowledge=agent_knowledge,
     tools=[PythonTools(base_dir=tmp_dir.joinpath("agents"), read_files=True)],
-    storage=SqliteAgentStorage(
-        table_name="agno_assist_sessions", db_file="tmp/agents.db"
-    ),
+    storage=SqliteStorage(table_name="agno_assist_sessions", db_file="tmp/agents.db"),
     add_history_to_messages=True,
     add_datetime_to_instructions=True,
     markdown=True,
diff --git a/cookbook/playground/audio_conversation_agent.py b/cookbook/playground/audio_conversation_agent.py
index 51c57e85d6..241bb248a1 100644
--- a/cookbook/playground/audio_conversation_agent.py
+++ b/cookbook/playground/audio_conversation_agent.py
@@ -1,7 +1,7 @@
 from agno.agent import Agent
 from agno.models.openai import OpenAIChat
 from agno.playground import Playground, serve_playground_app
-from agno.storage.agent.sqlite import SqliteAgentStorage
+from agno.storage.sqlite import SqliteStorage
 
 db_url = "postgresql+psycopg://ai:ai@localhost:5532/ai"
 audio_and_text_agent = Agent(
@@ -15,7 +15,7 @@
     debug_mode=True,
     add_history_to_messages=True,
     add_datetime_to_instructions=True,
-    storage=SqliteAgentStorage(table_name="audio_agent", db_file="tmp/audio_agent.db"),
+    storage=SqliteStorage(table_name="audio_agent", db_file="tmp/audio_agent.db"),
 )
 
 app = Playground(agents=[audio_and_text_agent]).get_app()
diff --git a/cookbook/playground/azure_openai_agents.py b/cookbook/playground/azure_openai_agents.py
index 696a6310a9..9415611ed1 100644
--- a/cookbook/playground/azure_openai_agents.py
+++ b/cookbook/playground/azure_openai_agents.py
@@ -6,7 +6,7 @@
 from agno.agent import Agent
 from agno.models.azure.openai_chat import AzureOpenAI
 from agno.playground import Playground, serve_playground_app
-from agno.storage.agent.sqlite import SqliteAgentStorage
+from agno.storage.sqlite import SqliteStorage
 from agno.tools.dalle import DalleTools
 from agno.tools.duckduckgo import DuckDuckGoTools
 from agno.tools.exa import ExaTools
@@ -25,7 +25,7 @@
         "Break down the users request into 2-3 different searches.",
         "Always include sources",
     ],
-    storage=SqliteAgentStorage(table_name="web_agent", db_file=agent_storage_file),
+    storage=SqliteStorage(table_name="web_agent", db_file=agent_storage_file),
     add_history_to_messages=True,
     num_history_responses=5,
     add_datetime_to_instructions=True,
@@ -46,7 +46,7 @@
         )
     ],
     instructions=["Always use tables to display data"],
-    storage=SqliteAgentStorage(table_name="finance_agent", db_file=agent_storage_file),
+    storage=SqliteStorage(table_name="finance_agent", db_file=agent_storage_file),
     add_history_to_messages=True,
     num_history_responses=5,
     add_datetime_to_instructions=True,
@@ -67,7 +67,7 @@
     debug_mode=True,
     add_history_to_messages=True,
     add_datetime_to_instructions=True,
-    storage=SqliteAgentStorage(table_name="image_agent", db_file=agent_storage_file),
+    storage=SqliteStorage(table_name="image_agent", db_file=agent_storage_file),
 )
 
 research_agent = Agent(
@@ -111,7 +111,7 @@
         - [Reference 1](link)
         - [Reference 2](link)
     """),
-    storage=SqliteAgentStorage(table_name="research_agent", db_file=agent_storage_file),
+    storage=SqliteStorage(table_name="research_agent", db_file=agent_storage_file),
     add_history_to_messages=True,
     add_datetime_to_instructions=True,
     markdown=True,
@@ -133,7 +133,7 @@
     num_history_responses=5,
     show_tool_calls=True,
     add_datetime_to_instructions=True,
-    storage=SqliteAgentStorage(table_name="youtube_agent", db_file=agent_storage_file),
+    storage=SqliteStorage(table_name="youtube_agent", db_file=agent_storage_file),
     markdown=True,
 )
diff --git a/cookbook/playground/blog_to_podcast.py b/cookbook/playground/blog_to_podcast.py
index 8b4c0281dd..c9963d7aac 100644
--- a/cookbook/playground/blog_to_podcast.py
+++ b/cookbook/playground/blog_to_podcast.py
@@ -1,7 +1,7 @@
 from agno.agent import Agent
 from agno.models.openai import OpenAIChat
 from agno.playground import Playground, serve_playground_app
-from agno.storage.agent.sqlite import SqliteAgentStorage
+from agno.storage.sqlite import SqliteStorage
 from agno.tools.eleven_labs import ElevenLabsTools
 from agno.tools.firecrawl import FirecrawlTools
 
@@ -35,7 +35,7 @@
     markdown=True,
     debug_mode=True,
     add_history_to_messages=True,
-    storage=SqliteAgentStorage(
+    storage=SqliteStorage(
         table_name="blog_to_podcast_agent", db_file=image_agent_storage_file
     ),
 )
diff --git a/cookbook/playground/coding_agent.py b/cookbook/playground/coding_agent.py
index 3cff49030b..59d6e717ce 100644
--- a/cookbook/playground/coding_agent.py
+++ b/cookbook/playground/coding_agent.py
@@ -3,7 +3,7 @@
 from agno.agent import Agent
 from agno.models.ollama import Ollama
 from agno.playground import Playground, serve_playground_app
-from agno.storage.agent.sqlite import SqliteAgentStorage
+from agno.storage.sqlite import SqliteStorage
 
 local_agent_storage_file: str = "tmp/local_agents.db"
 common_instructions = [
@@ -19,9 +19,7 @@
     add_history_to_messages=True,
     description="You are a coding agent",
     add_datetime_to_instructions=True,
-    storage=SqliteAgentStorage(
-        table_name="coding_agent", db_file=local_agent_storage_file
-    ),
+    storage=SqliteStorage(table_name="coding_agent", db_file=local_agent_storage_file),
 )
 
 app = Playground(agents=[coding_agent]).get_app()
diff --git a/cookbook/playground/demo.py b/cookbook/playground/demo.py
index 0de8cc23ab..31ef7c9c90 100644
--- a/cookbook/playground/demo.py
+++ b/cookbook/playground/demo.py
@@ -6,7 +6,7 @@
 from agno.agent import Agent
 from agno.models.openai import OpenAIChat
 from agno.playground import Playground, serve_playground_app
-from agno.storage.agent.sqlite import SqliteAgentStorage
+from agno.storage.sqlite import SqliteStorage
 from agno.tools.dalle import DalleTools
 from agno.tools.duckduckgo import DuckDuckGoTools
 from agno.tools.exa import ExaTools
@@ -22,7 +22,7 @@
     role="Answer basic questions",
     agent_id="simple-agent",
     model=OpenAIChat(id="gpt-4o-mini"),
-    storage=SqliteAgentStorage(table_name="simple_agent", db_file=agent_storage_file),
+    storage=SqliteStorage(table_name="simple_agent", db_file=agent_storage_file),
     add_history_to_messages=True,
     num_history_responses=3,
     add_datetime_to_instructions=True,
@@ -39,7 +39,7 @@
         "Break down the users request into 2-3 different searches.",
         "Always include sources",
     ],
-    storage=SqliteAgentStorage(table_name="web_agent", db_file=agent_storage_file),
+    storage=SqliteStorage(table_name="web_agent", db_file=agent_storage_file),
     add_history_to_messages=True,
     num_history_responses=5,
     add_datetime_to_instructions=True,
@@ -60,7 +60,7 @@
         )
     ],
     instructions=["Always use tables to display data"],
-    storage=SqliteAgentStorage(table_name="finance_agent", db_file=agent_storage_file),
+    storage=SqliteStorage(table_name="finance_agent", db_file=agent_storage_file),
     add_history_to_messages=True,
     num_history_responses=5,
     add_datetime_to_instructions=True,
@@ -81,9 +81,7 @@
     debug_mode=True,
     add_history_to_messages=True,
     add_datetime_to_instructions=True,
-    storage=SqliteAgentStorage(
-        table_name="image_agent", db_file=image_agent_storage_file
-    ),
+    storage=SqliteStorage(table_name="image_agent", db_file=image_agent_storage_file),
 )
 
 research_agent = Agent(
@@ -127,7 +125,7 @@
         - [Reference 1](link)
         - [Reference 2](link)
     """),
-    storage=SqliteAgentStorage(table_name="research_agent", db_file=agent_storage_file),
+    storage=SqliteStorage(table_name="research_agent", db_file=agent_storage_file),
     add_history_to_messages=True,
     add_datetime_to_instructions=True,
     markdown=True,
@@ -149,7 +147,7 @@
     num_history_responses=5,
     show_tool_calls=True,
     add_datetime_to_instructions=True,
-    storage=SqliteAgentStorage(table_name="youtube_agent", db_file=agent_storage_file),
+    storage=SqliteStorage(table_name="youtube_agent", db_file=agent_storage_file),
     markdown=True,
 )
diff --git a/cookbook/playground/grok_agents.py b/cookbook/playground/grok_agents.py
index a7679631d4..6e67f161ba 100644
--- a/cookbook/playground/grok_agents.py
+++ b/cookbook/playground/grok_agents.py
@@ -6,7 +6,7 @@
 from agno.agent import Agent
 from agno.models.xai import xAI
 from agno.playground import Playground, serve_playground_app
-from agno.storage.agent.sqlite import SqliteAgentStorage
+from agno.storage.sqlite import SqliteStorage
 from agno.tools.duckduckgo import DuckDuckGoTools
 from agno.tools.yfinance import YFinanceTools
 from agno.tools.youtube import YouTubeTools
@@ -27,7 +27,7 @@
         "Always include sources you used to generate the answer.",
     ]
     + common_instructions,
-    storage=SqliteAgentStorage(table_name="web_agent", db_file=xai_agent_storage),
+    storage=SqliteStorage(table_name="web_agent", db_file=xai_agent_storage),
     show_tool_calls=True,
     add_history_to_messages=True,
     num_history_responses=2,
@@ -51,7 +51,7 @@
     ],
     description="You are an investment analyst that researches stocks and helps users make informed decisions.",
     instructions=["Always use tables to display data"] + common_instructions,
-    storage=SqliteAgentStorage(table_name="finance_agent", db_file=xai_agent_storage),
+    storage=SqliteStorage(table_name="finance_agent", db_file=xai_agent_storage),
     show_tool_calls=True,
     add_history_to_messages=True,
     num_history_responses=5,
@@ -75,7 +75,7 @@
         "Keep your answers concise and engaging.",
     ]
     + common_instructions,
-    storage=SqliteAgentStorage(table_name="youtube_agent", db_file=xai_agent_storage),
+    storage=SqliteStorage(table_name="youtube_agent", db_file=xai_agent_storage),
     show_tool_calls=True,
     add_history_to_messages=True,
     num_history_responses=5,
diff --git a/cookbook/playground/groq_agents.py b/cookbook/playground/groq_agents.py
index d18d0e421b..1803ab6574 100644
--- a/cookbook/playground/groq_agents.py
+++ b/cookbook/playground/groq_agents.py
@@ -6,7 +6,7 @@
 from agno.agent import Agent
 from agno.models.groq import Groq
 from agno.playground import Playground, serve_playground_app
-from agno.storage.agent.sqlite import SqliteAgentStorage
+from agno.storage.sqlite import SqliteStorage
 from agno.tools.duckduckgo import DuckDuckGoTools
 from agno.tools.yfinance import YFinanceTools
 from agno.tools.youtube import YouTubeTools
@@ -27,7 +27,7 @@
         "Always include sources you used to generate the answer.",
     ]
     + common_instructions,
-    storage=SqliteAgentStorage(table_name="web_agent", db_file=xai_agent_storage),
+    storage=SqliteStorage(table_name="web_agent", db_file=xai_agent_storage),
     show_tool_calls=True,
     add_history_to_messages=True,
     num_history_responses=2,
@@ -51,7 +51,7 @@
     ],
     description="You are an investment analyst that researches stocks and helps users make informed decisions.",
     instructions=["Always use tables to display data"] + common_instructions,
-    storage=SqliteAgentStorage(table_name="finance_agent", db_file=xai_agent_storage),
+    storage=SqliteStorage(table_name="finance_agent", db_file=xai_agent_storage),
     show_tool_calls=True,
     add_history_to_messages=True,
     num_history_responses=5,
@@ -76,7 +76,7 @@
         "If the user just provides a URL, summarize the video and answer questions about it.",
     ]
     + common_instructions,
-    storage=SqliteAgentStorage(table_name="youtube_agent", db_file=xai_agent_storage),
+    storage=SqliteStorage(table_name="youtube_agent", db_file=xai_agent_storage),
     show_tool_calls=True,
     add_history_to_messages=True,
     num_history_responses=5,
diff --git a/cookbook/playground/multimodal_agents.py b/cookbook/playground/multimodal_agents.py
index 2debd7eeab..f20357e35d 100644
--- a/cookbook/playground/multimodal_agents.py
+++ b/cookbook/playground/multimodal_agents.py
@@ -10,7 +10,7 @@
 from agno.models.openai import OpenAIChat
 from agno.models.response import FileType
 from agno.playground import Playground, serve_playground_app
-from agno.storage.agent.sqlite import SqliteAgentStorage
+from agno.storage.sqlite import SqliteStorage
 from agno.tools.dalle import DalleTools
 from agno.tools.eleven_labs import ElevenLabsTools
 from agno.tools.fal import FalTools
@@ -33,9 +33,7 @@
     debug_mode=True,
     add_history_to_messages=True,
     add_datetime_to_instructions=True,
-    storage=SqliteAgentStorage(
-        table_name="image_agent", db_file=image_agent_storage_file
-    ),
+    storage=SqliteStorage(table_name="image_agent", db_file=image_agent_storage_file),
 )
 
 ml_gif_agent = Agent(
@@ -52,9 +50,7 @@
     debug_mode=True,
     add_history_to_messages=True,
     add_datetime_to_instructions=True,
-    storage=SqliteAgentStorage(
-        table_name="ml_gif_agent", db_file=image_agent_storage_file
-    ),
+    storage=SqliteStorage(table_name="ml_gif_agent", db_file=image_agent_storage_file),
 )
 
 ml_music_agent = Agent(
@@ -78,7 +74,7 @@
     debug_mode=True,
     add_history_to_messages=True,
     add_datetime_to_instructions=True,
-    storage=SqliteAgentStorage(
+    storage=SqliteStorage(
         table_name="ml_music_agent", db_file=image_agent_storage_file
     ),
 )
@@ -97,7 +93,7 @@
     debug_mode=True,
     add_history_to_messages=True,
     add_datetime_to_instructions=True,
-    storage=SqliteAgentStorage(
+    storage=SqliteStorage(
         table_name="ml_video_agent", db_file=image_agent_storage_file
     ),
 )
@@ -116,9 +112,7 @@
     debug_mode=True,
     add_history_to_messages=True,
     add_datetime_to_instructions=True,
-    storage=SqliteAgentStorage(
-        table_name="fal_agent", db_file=image_agent_storage_file
-    ),
+    storage=SqliteStorage(table_name="fal_agent", db_file=image_agent_storage_file),
 )
 
 gif_agent = Agent(
@@ -135,9 +129,7 @@
     debug_mode=True,
     add_history_to_messages=True,
     add_datetime_to_instructions=True,
-    storage=SqliteAgentStorage(
-        table_name="gif_agent", db_file=image_agent_storage_file
-    ),
+    storage=SqliteStorage(table_name="gif_agent", db_file=image_agent_storage_file),
 )
 
 audio_agent = Agent(
@@ -163,9 +155,7 @@
     debug_mode=True,
     add_history_to_messages=True,
     add_datetime_to_instructions=True,
-    storage=SqliteAgentStorage(
-        table_name="audio_agent", db_file=image_agent_storage_file
-    ),
+    storage=SqliteStorage(table_name="audio_agent", db_file=image_agent_storage_file),
 )
diff --git a/cookbook/playground/ollama_agents.py b/cookbook/playground/ollama_agents.py
index b93afca43d..7a897d5c5d 100644
--- a/cookbook/playground/ollama_agents.py
+++ b/cookbook/playground/ollama_agents.py
@@ -3,7 +3,7 @@
 from agno.agent import Agent
 from agno.models.ollama import Ollama
 from agno.playground import Playground, serve_playground_app
-from agno.storage.agent.sqlite import SqliteAgentStorage
+from agno.storage.sqlite import SqliteStorage
 from agno.tools.duckduckgo import DuckDuckGoTools
 from agno.tools.yfinance import YFinanceTools
 from agno.tools.youtube import YouTubeTools
@@ -20,9 +20,7 @@
     model=Ollama(id="llama3.1:8b"),
     tools=[DuckDuckGoTools()],
     instructions=["Always include sources."] + common_instructions,
-    storage=SqliteAgentStorage(
-        table_name="web_agent", db_file=local_agent_storage_file
-    ),
+    storage=SqliteStorage(table_name="web_agent", db_file=local_agent_storage_file),
     show_tool_calls=True,
     add_history_to_messages=True,
     num_history_responses=2,
@@ -46,9 +44,7 @@
     ],
     description="You are an investment analyst that researches stocks and helps users make informed decisions.",
     instructions=["Always use tables to display data"] + common_instructions,
-    storage=SqliteAgentStorage(
-        table_name="finance_agent", db_file=local_agent_storage_file
-    ),
+    storage=SqliteStorage(table_name="finance_agent", db_file=local_agent_storage_file),
     add_history_to_messages=True,
     num_history_responses=5,
     add_name_to_instructions=True,
@@ -76,9 +72,7 @@
     show_tool_calls=True,
     add_name_to_instructions=True,
     add_datetime_to_instructions=True,
-    storage=SqliteAgentStorage(
-        table_name="youtube_agent", db_file=local_agent_storage_file
-    ),
+    storage=SqliteStorage(table_name="youtube_agent", db_file=local_agent_storage_file),
     markdown=True,
 )
diff --git a/cookbook/playground/upload_files.py b/cookbook/playground/upload_files.py
index 11b3b1dd12..70d2f73169 100644
--- a/cookbook/playground/upload_files.py
+++ b/cookbook/playground/upload_files.py
@@ -9,7 +9,7 @@
 from agno.models.openai import OpenAIChat
 from agno.playground.playground import Playground
 from agno.playground.serve import serve_playground_app
-from agno.storage.agent.postgres import PostgresAgentStorage
+from agno.storage.postgres import PostgresStorage
 from agno.vectordb.pgvector import PgVector
 
 db_url = "postgresql+psycopg://ai:ai@localhost:5532/ai"
@@ -40,7 +40,7 @@
     agent_id="file-upload-agent",
     role="Answer questions about the uploaded files",
     model=OpenAIChat(id="gpt-4o-mini"),
-    storage=PostgresAgentStorage(table_name="agent_sessions", db_url=db_url),
+    storage=PostgresStorage(table_name="agent_sessions", db_url=db_url),
     knowledge=knowledge_base,
     show_tool_calls=True,
     markdown=True,
@@ -52,7 +52,7 @@
     agent_id="audio-understanding-agent",
     role="Answer questions about audio files",
     model=OpenAIChat(id="gpt-4o-audio-preview"),
-    storage=PostgresAgentStorage(table_name="agent_sessions", db_url=db_url),
+    storage=PostgresStorage(table_name="agent_sessions", db_url=db_url),
     add_history_to_messages=True,
     add_datetime_to_instructions=True,
     show_tool_calls=True,
@@ -64,7 +64,7 @@
     model=Gemini(id="gemini-2.0-flash"),
     agent_id="video-understanding-agent",
     role="Answer questions about video files",
-    storage=PostgresAgentStorage(table_name="agent_sessions", db_url=db_url),
+    storage=PostgresStorage(table_name="agent_sessions", db_url=db_url),
     add_history_to_messages=True,
     add_datetime_to_instructions=True,
     show_tool_calls=True,
diff --git a/cookbook/agent_concepts/storage/__init__.py b/cookbook/storage/__init__.py
similarity index 100%
rename from cookbook/agent_concepts/storage/__init__.py
rename to cookbook/storage/__init__.py
diff --git a/cookbook/storage/dynamodb_storage/__init__.py b/cookbook/storage/dynamodb_storage/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/cookbook/agent_concepts/storage/dynamodb_storage.py b/cookbook/storage/dynamodb_storage/dynamodb_storage_for_agent.py
similarity index 71%
rename from cookbook/agent_concepts/storage/dynamodb_storage.py
rename to cookbook/storage/dynamodb_storage/dynamodb_storage_for_agent.py
index f0629d60d3..87137e1add 100644
--- a/cookbook/agent_concepts/storage/dynamodb_storage.py
+++ b/cookbook/storage/dynamodb_storage/dynamodb_storage_for_agent.py
@@ -1,11 +1,11 @@
 """Run `pip install duckduckgo-search boto3 openai` to install dependencies."""
 
 from agno.agent import Agent
-from agno.storage.agent.dynamodb import DynamoDbAgentStorage
+from agno.storage.dynamodb import DynamoDbStorage
 from agno.tools.duckduckgo import DuckDuckGoTools
 
 agent = Agent(
-    storage=DynamoDbAgentStorage(table_name="agent_sessions", region_name="us-east-1"),
+    storage=DynamoDbStorage(table_name="agent_sessions", region_name="us-east-1"),
     tools=[DuckDuckGoTools()],
     add_history_to_messages=True,
     debug_mode=True,
diff --git a/cookbook/storage/json_storage/__init__.py b/cookbook/storage/json_storage/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/cookbook/agent_concepts/storage/json_storage.py b/cookbook/storage/json_storage/json_storage_for_agent.py
similarity index 74%
rename from cookbook/agent_concepts/storage/json_storage.py
rename to cookbook/storage/json_storage/json_storage_for_agent.py
index 7afbd36b0c..a3f091bfad 100644
--- a/cookbook/agent_concepts/storage/json_storage.py
+++ b/cookbook/storage/json_storage/json_storage_for_agent.py
@@ -1,11 +1,11 @@
 """Run `pip install duckduckgo-search openai` to install dependencies."""
 
 from agno.agent import Agent
-from agno.storage.agent.json import JsonAgentStorage
+from agno.storage.json import JsonStorage
 from agno.tools.duckduckgo import DuckDuckGoTools
 
 agent = Agent(
-    storage=JsonAgentStorage(dir_path="tmp/agent_sessions_json"),
+    storage=JsonStorage(dir_path="tmp/agent_sessions_json"),
     tools=[DuckDuckGoTools()],
     add_history_to_messages=True,
 )
diff --git a/cookbook/storage/json_storage/json_storage_for_workflow.py b/cookbook/storage/json_storage/json_storage_for_workflow.py
new file mode 100644
index 0000000000..7e76fc98a5
--- /dev/null
+++ b/cookbook/storage/json_storage/json_storage_for_workflow.py
@@ -0,0 +1,88 @@
+import json
+from typing import Iterator
+
+import httpx
+from agno.agent import Agent
+from agno.run.response import RunResponse
+from agno.storage.json import JsonStorage
+from agno.tools.newspaper4k import Newspaper4kTools
+from agno.utils.log import logger
+from agno.utils.pprint import pprint_run_response
+from agno.workflow import Workflow
+
+
+class HackerNewsReporter(Workflow):
+    description: str = (
+        "Get the top stories from Hacker News and write a report on them."
+    )
+
+    hn_agent: Agent = Agent(
+        description="Get the top stories from hackernews. "
+        "Share all possible information, including url, score, title and summary if available.",
+        show_tool_calls=True,
+    )
+
+    writer: Agent = Agent(
+        tools=[Newspaper4kTools()],
+        description="Write an engaging report on the top stories from hackernews.",
+        instructions=[
+            "You will be provided with top stories and their links.",
+            "Carefully read each article and think about the contents",
+            "Then generate a final New York Times worthy article",
+            "Break the article into sections and provide key takeaways at the end.",
+            "Make sure the title is catchy and engaging.",
+            "Share score, title, url and summary of every article.",
+            "Give the section relevant titles and provide details/facts/processes in each section."
+ "Ignore articles that you cannot read or understand.", + "REMEMBER: you are writing for the New York Times, so the quality of the article is important.", + ], + ) + + def get_top_hackernews_stories(self, num_stories: int = 10) -> str: + """Use this function to get top stories from Hacker News. + + Args: + num_stories (int): Number of stories to return. Defaults to 10. + + Returns: + str: JSON string of top stories. + """ + + # Fetch top story IDs + response = httpx.get("https://hacker-news.firebaseio.com/v0/topstories.json") + story_ids = response.json() + + # Fetch story details + stories = [] + for story_id in story_ids[:num_stories]: + story_response = httpx.get( + f"https://hacker-news.firebaseio.com/v0/item/{story_id}.json" + ) + story = story_response.json() + story["username"] = story["by"] + stories.append(story) + return json.dumps(stories) + + def run(self, num_stories: int = 5) -> Iterator[RunResponse]: + # Set the tools for hn_agent here to avoid circular reference + self.hn_agent.tools = [self.get_top_hackernews_stories] + + logger.info(f"Getting top {num_stories} stories from HackerNews.") + top_stories: RunResponse = self.hn_agent.run(num_stories=num_stories) + if top_stories is None or not top_stories.content: + yield RunResponse( + run_id=self.run_id, content="Sorry, could not get the top stories." + ) + return + + logger.info("Reading each story and writing a report.") + yield from self.writer.run(top_stories.content, stream=True) + + +if __name__ == "__main__": + # Run workflow + report: Iterator[RunResponse] = HackerNewsReporter( + storage=JsonStorage(dir_path="tmp/workflow_sessions_json"), debug_mode=False + ).run(num_stories=5) + # Print the report + pprint_run_response(report, markdown=True, show_time=True) diff --git a/cookbook/storage/mongodb_storage/__init__.py b/cookbook/storage/mongodb_storage/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/cookbook/agent_concepts/storage/mongodb_storage.py b/cookbook/storage/mongodb_storage/mongodb_storage_for_agent.py similarity index 88% rename from cookbook/agent_concepts/storage/mongodb_storage.py rename to cookbook/storage/mongodb_storage/mongodb_storage_for_agent.py index a4ce1358e3..8be58b4d73 100644 --- a/cookbook/agent_concepts/storage/mongodb_storage.py +++ b/cookbook/storage/mongodb_storage/mongodb_storage_for_agent.py @@ -7,14 +7,14 @@ """ from agno.agent import Agent -from agno.storage.agent.mongodb import MongoDbAgentStorage +from agno.storage.mongodb import MongoDbStorage from agno.tools.duckduckgo import DuckDuckGoTools # MongoDB connection settings db_url = "mongodb://localhost:27017" agent = Agent( - storage=MongoDbAgentStorage( + storage=MongoDbStorage( collection_name="agent_sessions", db_url=db_url, db_name="agno" ), tools=[DuckDuckGoTools()], diff --git a/cookbook/storage/mongodb_storage/mongodb_storage_for_workflow.py b/cookbook/storage/mongodb_storage/mongodb_storage_for_workflow.py new file mode 100644 index 0000000000..2f2c3692df --- /dev/null +++ b/cookbook/storage/mongodb_storage/mongodb_storage_for_workflow.py @@ -0,0 +1,94 @@ +import json +from typing import Iterator + +import httpx +from agno.agent import Agent +from agno.run.response import RunResponse +from agno.storage.mongodb import MongoDbStorage +from agno.tools.newspaper4k import Newspaper4kTools +from agno.utils.log import logger +from agno.utils.pprint import pprint_run_response +from agno.workflow import Workflow + +db_url = "mongodb://localhost:27017" + + +class HackerNewsReporter(Workflow): + description: 
str = ( + "Get the top stories from Hacker News and write a report on them." + ) + + hn_agent: Agent = Agent( + description="Get the top stories from hackernews. " + "Share all possible information, including url, score, title and summary if available.", + show_tool_calls=True, + ) + + writer: Agent = Agent( + tools=[Newspaper4kTools()], + description="Write an engaging report on the top stories from hackernews.", + instructions=[ + "You will be provided with top stories and their links.", + "Carefully read each article and think about the contents", + "Then generate a final New York Times worthy article", + "Break the article into sections and provide key takeaways at the end.", + "Make sure the title is catchy and engaging.", + "Share score, title, url and summary of every article.", + "Give the section relevant titles and provide details/facts/processes in each section." + "Ignore articles that you cannot read or understand.", + "REMEMBER: you are writing for the New York Times, so the quality of the article is important.", + ], + ) + + def get_top_hackernews_stories(self, num_stories: int = 10) -> str: + """Use this function to get top stories from Hacker News. + + Args: + num_stories (int): Number of stories to return. Defaults to 10. + + Returns: + str: JSON string of top stories. + """ + + # Fetch top story IDs + response = httpx.get("https://hacker-news.firebaseio.com/v0/topstories.json") + story_ids = response.json() + + # Fetch story details + stories = [] + for story_id in story_ids[:num_stories]: + story_response = httpx.get( + f"https://hacker-news.firebaseio.com/v0/item/{story_id}.json" + ) + story = story_response.json() + story["username"] = story["by"] + stories.append(story) + return json.dumps(stories) + + def run(self, num_stories: int = 5) -> Iterator[RunResponse]: + # Set the tools for hn_agent here to avoid circular reference + self.hn_agent.tools = [self.get_top_hackernews_stories] + + logger.info(f"Getting top {num_stories} stories from HackerNews.") + top_stories: RunResponse = self.hn_agent.run(num_stories=num_stories) + if top_stories is None or not top_stories.content: + yield RunResponse( + run_id=self.run_id, content="Sorry, could not get the top stories." 
+            )
+            return
+
+        logger.info("Reading each story and writing a report.")
+        yield from self.writer.run(top_stories.content, stream=True)
+
+
+if __name__ == "__main__":
+    # Run workflow
+    storage = MongoDbStorage(
+        collection_name="agent_sessions", db_url=db_url, db_name="agno"
+    )
+    storage.drop()
+    report: Iterator[RunResponse] = HackerNewsReporter(
+        storage=storage, debug_mode=False
+    ).run(num_stories=5)
+    # Print the report
+    pprint_run_response(report, markdown=True, show_time=True)
diff --git a/cookbook/storage/postgres_storage/__init__.py b/cookbook/storage/postgres_storage/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/cookbook/agent_concepts/storage/postgres_storage.py b/cookbook/storage/postgres_storage/postgres_storage_for_agent.py
similarity index 75%
rename from cookbook/agent_concepts/storage/postgres_storage.py
rename to cookbook/storage/postgres_storage/postgres_storage_for_agent.py
index 5bf06173e7..9467a90314 100644
--- a/cookbook/agent_concepts/storage/postgres_storage.py
+++ b/cookbook/storage/postgres_storage/postgres_storage_for_agent.py
@@ -1,13 +1,13 @@
 """Run `pip install duckduckgo-search sqlalchemy openai` to install dependencies."""
 
 from agno.agent import Agent
-from agno.storage.agent.postgres import PostgresAgentStorage
+from agno.storage.postgres import PostgresStorage
 from agno.tools.duckduckgo import DuckDuckGoTools
 
 db_url = "postgresql+psycopg://ai:ai@localhost:5532/ai"
 
 agent = Agent(
-    storage=PostgresAgentStorage(table_name="agent_sessions", db_url=db_url),
+    storage=PostgresStorage(table_name="agent_sessions", db_url=db_url),
     tools=[DuckDuckGoTools()],
     add_history_to_messages=True,
 )
diff --git a/cookbook/storage/postgres_storage/postgres_storage_for_workflow.py b/cookbook/storage/postgres_storage/postgres_storage_for_workflow.py
new file mode 100644
index 0000000000..98ac69ac2c
--- /dev/null
+++ b/cookbook/storage/postgres_storage/postgres_storage_for_workflow.py
@@ -0,0 +1,92 @@
+import json
+from typing import Iterator
+
+import httpx
+from agno.agent import Agent
+from agno.run.response import RunResponse
+from agno.storage.postgres import PostgresStorage
+from agno.tools.newspaper4k import Newspaper4kTools
+from agno.utils.log import logger
+from agno.utils.pprint import pprint_run_response
+from agno.workflow import Workflow
+
+db_url = "postgresql+psycopg://ai:ai@localhost:5532/ai"
+
+
+class HackerNewsReporter(Workflow):
+    description: str = (
+        "Get the top stories from Hacker News and write a report on them."
+    )
+
+    hn_agent: Agent = Agent(
+        description="Get the top stories from hackernews. "
+        "Share all possible information, including url, score, title and summary if available.",
+        show_tool_calls=True,
+    )
+
+    writer: Agent = Agent(
+        tools=[Newspaper4kTools()],
+        description="Write an engaging report on the top stories from hackernews.",
+        instructions=[
+            "You will be provided with top stories and their links.",
+            "Carefully read each article and think about the contents",
+            "Then generate a final New York Times worthy article",
+            "Break the article into sections and provide key takeaways at the end.",
+            "Make sure the title is catchy and engaging.",
+            "Share score, title, url and summary of every article.",
+            "Give the section relevant titles and provide details/facts/processes in each section."
+ "Ignore articles that you cannot read or understand.", + "REMEMBER: you are writing for the New York Times, so the quality of the article is important.", + ], + ) + + def get_top_hackernews_stories(self, num_stories: int = 10) -> str: + """Use this function to get top stories from Hacker News. + + Args: + num_stories (int): Number of stories to return. Defaults to 10. + + Returns: + str: JSON string of top stories. + """ + + # Fetch top story IDs + response = httpx.get("https://hacker-news.firebaseio.com/v0/topstories.json") + story_ids = response.json() + + # Fetch story details + stories = [] + for story_id in story_ids[:num_stories]: + story_response = httpx.get( + f"https://hacker-news.firebaseio.com/v0/item/{story_id}.json" + ) + story = story_response.json() + story["username"] = story["by"] + stories.append(story) + return json.dumps(stories) + + def run(self, num_stories: int = 5) -> Iterator[RunResponse]: + # Set the tools for hn_agent here to avoid circular reference + self.hn_agent.tools = [self.get_top_hackernews_stories] + + logger.info(f"Getting top {num_stories} stories from HackerNews.") + top_stories: RunResponse = self.hn_agent.run(num_stories=num_stories) + if top_stories is None or not top_stories.content: + yield RunResponse( + run_id=self.run_id, content="Sorry, could not get the top stories." + ) + return + + logger.info("Reading each story and writing a report.") + yield from self.writer.run(top_stories.content, stream=True) + + +if __name__ == "__main__": + # Run workflow + storage = PostgresStorage(table_name="agent_sessions", db_url=db_url) + storage.drop() + report: Iterator[RunResponse] = HackerNewsReporter( + storage=storage, debug_mode=False + ).run(num_stories=5) + # Print the report + pprint_run_response(report, markdown=True, show_time=True) diff --git a/cookbook/storage/singlestore_storage/__init__.py b/cookbook/storage/singlestore_storage/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/cookbook/agent_concepts/storage/singlestore_storage.py b/cookbook/storage/singlestore_storage/singlestore_storage_for_agent.py similarity index 70% rename from cookbook/agent_concepts/storage/singlestore_storage.py rename to cookbook/storage/singlestore_storage/singlestore_storage_for_agent.py index c2be52bc7d..6b59e3ffd4 100644 --- a/cookbook/agent_concepts/storage/singlestore_storage.py +++ b/cookbook/storage/singlestore_storage/singlestore_storage_for_agent.py @@ -1,10 +1,12 @@ """Run `pip install duckduckgo-search sqlalchemy openai` to install dependencies.""" +import os from os import getenv from agno.agent import Agent -from agno.storage.agent.singlestore import SingleStoreAgentStorage +from agno.storage.singlestore import SingleStoreStorage from agno.tools.duckduckgo import DuckDuckGoTools +from agno.utils.certs import download_cert from sqlalchemy.engine import create_engine # Configure SingleStore DB connection @@ -15,6 +17,17 @@ DATABASE = getenv("SINGLESTORE_DATABASE") SSL_CERT = getenv("SINGLESTORE_SSL_CERT", None) + +# Download the certificate if SSL_CERT is not provided +if not SSL_CERT: + SSL_CERT = download_cert( + cert_url="https://portal.singlestore.com/static/ca/singlestore_bundle.pem", + filename="singlestore_bundle.pem", + ) + if SSL_CERT: + os.environ["SINGLESTORE_SSL_CERT"] = SSL_CERT + + # SingleStore DB URL db_url = ( f"mysql+pymysql://{USERNAME}:{PASSWORD}@{HOST}:{PORT}/{DATABASE}?charset=utf8mb4" @@ -27,7 +40,7 @@ # Create an agent with SingleStore storage agent = Agent( - storage=SingleStoreAgentStorage( + 
storage=SingleStoreStorage( + table_name="agent_sessions", db_engine=db_engine, schema=DATABASE + ), + tools=[DuckDuckGoTools()], diff --git a/cookbook/storage/singlestore_storage/singlestore_storage_for_workflow.py b/cookbook/storage/singlestore_storage/singlestore_storage_for_workflow.py new file mode 100644 index 0000000000..df53d18f12 --- /dev/null +++ b/cookbook/storage/singlestore_storage/singlestore_storage_for_workflow.py @@ -0,0 +1,121 @@ +import json +import os +from os import getenv +from typing import Iterator + +import httpx +from agno.agent import Agent +from agno.run.response import RunResponse +from agno.storage.singlestore import SingleStoreStorage +from agno.tools.newspaper4k import Newspaper4kTools +from agno.utils.certs import download_cert +from agno.utils.log import logger +from agno.utils.pprint import pprint_run_response +from agno.workflow import Workflow +from sqlalchemy.engine import create_engine + + +class HackerNewsReporter(Workflow): + description: str = ( + "Get the top stories from Hacker News and write a report on them." + ) + + hn_agent: Agent = Agent( + description="Get the top stories from hackernews. " + "Share all possible information, including url, score, title and summary if available.", + show_tool_calls=True, + ) + + writer: Agent = Agent( + tools=[Newspaper4kTools()], + description="Write an engaging report on the top stories from hackernews.", + instructions=[ + "You will be provided with top stories and their links.", + "Carefully read each article and think about the contents", + "Then generate a final New York Times worthy article", + "Break the article into sections and provide key takeaways at the end.", + "Make sure the title is catchy and engaging.", + "Share score, title, url and summary of every article.", + "Give the section relevant titles and provide details/facts/processes in each section.", + "Ignore articles that you cannot read or understand.", + "REMEMBER: you are writing for the New York Times, so the quality of the article is important.", + ], + ) + + def get_top_hackernews_stories(self, num_stories: int = 10) -> str: + """Use this function to get top stories from Hacker News. + + Args: + num_stories (int): Number of stories to return. Defaults to 10. + + Returns: + str: JSON string of top stories. + """ + + # Fetch top story IDs + response = httpx.get("https://hacker-news.firebaseio.com/v0/topstories.json") + story_ids = response.json() + + # Fetch story details + stories = [] + for story_id in story_ids[:num_stories]: + story_response = httpx.get( + f"https://hacker-news.firebaseio.com/v0/item/{story_id}.json" + ) + story = story_response.json() + story["username"] = story["by"] + stories.append(story) + return json.dumps(stories) + + def run(self, num_stories: int = 5) -> Iterator[RunResponse]: + # Set the tools for hn_agent here to avoid circular reference + self.hn_agent.tools = [self.get_top_hackernews_stories] + + logger.info(f"Getting top {num_stories} stories from HackerNews.") + top_stories: RunResponse = self.hn_agent.run(num_stories=num_stories) + if top_stories is None or not top_stories.content: + yield RunResponse( + run_id=self.run_id, content="Sorry, could not get the top stories."
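+ # Note: the falsy check above treats an empty content string the same as no response at all, so blank model output also triggers this fallback.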
+ ) + return + + logger.info("Reading each story and writing a report.") + yield from self.writer.run(top_stories.content, stream=True) + + +if __name__ == "__main__": + USERNAME = getenv("SINGLESTORE_USERNAME") + PASSWORD = getenv("SINGLESTORE_PASSWORD") + HOST = getenv("SINGLESTORE_HOST") + PORT = getenv("SINGLESTORE_PORT") + DATABASE = getenv("SINGLESTORE_DATABASE") + SSL_CERT = getenv("SINGLESTORE_SSL_CERT", None) + + # Download the certificate if SSL_CERT is not provided + if not SSL_CERT: + SSL_CERT = download_cert( + cert_url="https://portal.singlestore.com/static/ca/singlestore_bundle.pem", + filename="singlestore_bundle.pem", + ) + if SSL_CERT: + os.environ["SINGLESTORE_SSL_CERT"] = SSL_CERT + + # SingleStore DB URL + db_url = f"mysql+pymysql://{USERNAME}:{PASSWORD}@{HOST}:{PORT}/{DATABASE}?charset=utf8mb4" + if SSL_CERT: + db_url += f"&ssl_ca={SSL_CERT}&ssl_verify_cert=true" + + # Create a DB engine + db_engine = create_engine(db_url) + # Run workflow + report: Iterator[RunResponse] = HackerNewsReporter( + storage=SingleStoreStorage( + table_name="workflow_sessions", + mode="workflow", + db_engine=db_engine, + schema=DATABASE, + ), + debug_mode=False, + ).run(num_stories=5) + # Print the report + pprint_run_response(report, markdown=True, show_time=True) diff --git a/cookbook/storage/sqllite_storage/__init__.py b/cookbook/storage/sqllite_storage/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/cookbook/agent_concepts/storage/sqlite_storage.py b/cookbook/storage/sqllite_storage/sqlite_storage_for_agent.py similarity index 72% rename from cookbook/agent_concepts/storage/sqlite_storage.py rename to cookbook/storage/sqllite_storage/sqlite_storage_for_agent.py index 3ea55ac97a..79e6101a6b 100644 --- a/cookbook/agent_concepts/storage/sqlite_storage.py +++ b/cookbook/storage/sqllite_storage/sqlite_storage_for_agent.py @@ -1,11 +1,11 @@ """Run `pip install duckduckgo-search sqlalchemy openai` to install dependencies.""" from agno.agent import Agent -from agno.storage.agent.sqlite import SqliteAgentStorage +from agno.storage.sqlite import SqliteStorage from agno.tools.duckduckgo import DuckDuckGoTools agent = Agent( - storage=SqliteAgentStorage(table_name="agent_sessions", db_file="tmp/data.db"), + storage=SqliteStorage(table_name="agent_sessions", db_file="tmp/data.db"), tools=[DuckDuckGoTools()], add_history_to_messages=True, ) diff --git a/cookbook/storage/sqllite_storage/sqlite_storage_for_workflow.py b/cookbook/storage/sqllite_storage/sqlite_storage_for_workflow.py new file mode 100644 index 0000000000..14e10270fb --- /dev/null +++ b/cookbook/storage/sqllite_storage/sqlite_storage_for_workflow.py @@ -0,0 +1,89 @@ +import json +from typing import Iterator + +import httpx +from agno.agent import Agent +from agno.run.response import RunResponse +from agno.storage.sqlite import SqliteStorage +from agno.tools.newspaper4k import Newspaper4kTools +from agno.utils.log import logger +from agno.utils.pprint import pprint_run_response +from agno.workflow import Workflow + + +class HackerNewsReporter(Workflow): + description: str = ( + "Get the top stories from Hacker News and write a report on them." + ) + + hn_agent: Agent = Agent( + description="Get the top stories from hackernews. 
" + "Share all possible information, including url, score, title and summary if available.", + show_tool_calls=True, + ) + + writer: Agent = Agent( + tools=[Newspaper4kTools()], + description="Write an engaging report on the top stories from hackernews.", + instructions=[ + "You will be provided with top stories and their links.", + "Carefully read each article and think about the contents", + "Then generate a final New York Times worthy article", + "Break the article into sections and provide key takeaways at the end.", + "Make sure the title is catchy and engaging.", + "Share score, title, url and summary of every article.", + "Give the section relevant titles and provide details/facts/processes in each section." + "Ignore articles that you cannot read or understand.", + "REMEMBER: you are writing for the New York Times, so the quality of the article is important.", + ], + ) + + def get_top_hackernews_stories(self, num_stories: int = 10) -> str: + """Use this function to get top stories from Hacker News. + + Args: + num_stories (int): Number of stories to return. Defaults to 10. + + Returns: + str: JSON string of top stories. + """ + + # Fetch top story IDs + response = httpx.get("https://hacker-news.firebaseio.com/v0/topstories.json") + story_ids = response.json() + + # Fetch story details + stories = [] + for story_id in story_ids[:num_stories]: + story_response = httpx.get( + f"https://hacker-news.firebaseio.com/v0/item/{story_id}.json" + ) + story = story_response.json() + story["username"] = story["by"] + stories.append(story) + return json.dumps(stories) + + def run(self, num_stories: int = 5) -> Iterator[RunResponse]: + # Set the tools for hn_agent here to avoid circular reference + self.hn_agent.tools = [self.get_top_hackernews_stories] + + logger.info(f"Getting top {num_stories} stories from HackerNews.") + top_stories: RunResponse = self.hn_agent.run(num_stories=num_stories) + if top_stories is None or not top_stories.content: + yield RunResponse( + run_id=self.run_id, content="Sorry, could not get the top stories." 
+ ) + return + + logger.info("Reading each story and writing a report.") + yield from self.writer.run(top_stories.content, stream=True) + + +if __name__ == "__main__": + # Run workflow + storage = SqliteStorage(table_name="workflow_sessions", db_file="tmp/data.db") + report: Iterator[RunResponse] = HackerNewsReporter( + storage=storage, debug_mode=False + ).run(num_stories=5) + # Print the report + pprint_run_response(report, markdown=True, show_time=True) diff --git a/cookbook/storage/yaml_storage/__init__.py b/cookbook/storage/yaml_storage/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/cookbook/agent_concepts/storage/yaml_storage.py b/cookbook/storage/yaml_storage/yaml_storage_for_agent.py similarity index 74% rename from cookbook/agent_concepts/storage/yaml_storage.py rename to cookbook/storage/yaml_storage/yaml_storage_for_agent.py index c4ab78560f..564dba2e68 100644 --- a/cookbook/agent_concepts/storage/yaml_storage.py +++ b/cookbook/storage/yaml_storage/yaml_storage_for_agent.py @@ -1,11 +1,11 @@ """Run `pip install duckduckgo-search openai` to install dependencies.""" from agno.agent import Agent -from agno.storage.agent.yaml import YamlAgentStorage +from agno.storage.yaml import YamlStorage from agno.tools.duckduckgo import DuckDuckGoTools agent = Agent( - storage=YamlAgentStorage(dir_path="tmp/agent_sessions_yaml"), + storage=YamlStorage(dir_path="tmp/agent_sessions_yaml"), tools=[DuckDuckGoTools()], add_history_to_messages=True, ) diff --git a/cookbook/storage/yaml_storage/yaml_storage_for_workflow.py b/cookbook/storage/yaml_storage/yaml_storage_for_workflow.py new file mode 100644 index 0000000000..095974afeb --- /dev/null +++ b/cookbook/storage/yaml_storage/yaml_storage_for_workflow.py @@ -0,0 +1,88 @@ +import json +from typing import Iterator + +import httpx +from agno.agent import Agent +from agno.run.response import RunResponse +from agno.storage.yaml import YamlStorage +from agno.tools.newspaper4k import Newspaper4kTools +from agno.utils.log import logger +from agno.utils.pprint import pprint_run_response +from agno.workflow import Workflow + + +class HackerNewsReporter(Workflow): + description: str = ( + "Get the top stories from Hacker News and write a report on them." + ) + + hn_agent: Agent = Agent( + description="Get the top stories from hackernews. " + "Share all possible information, including url, score, title and summary if available.", + show_tool_calls=True, + ) + + writer: Agent = Agent( + tools=[Newspaper4kTools()], + description="Write an engaging report on the top stories from hackernews.", + instructions=[ + "You will be provided with top stories and their links.", + "Carefully read each article and think about the contents", + "Then generate a final New York Times worthy article", + "Break the article into sections and provide key takeaways at the end.", + "Make sure the title is catchy and engaging.", + "Share score, title, url and summary of every article.", + "Give the section relevant titles and provide details/facts/processes in each section.", + "Ignore articles that you cannot read or understand.", + "REMEMBER: you are writing for the New York Times, so the quality of the article is important.", + ], + ) + + def get_top_hackernews_stories(self, num_stories: int = 10) -> str: + """Use this function to get top stories from Hacker News. + + Args: + num_stories (int): Number of stories to return. Defaults to 10. + + Returns: + str: JSON string of top stories.
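+ Example (illustrative usage of this helper): + >>> reporter = HackerNewsReporter() + >>> stories_json = reporter.get_top_hackernews_stories(num_stories=2)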
+ """ + + # Fetch top story IDs + response = httpx.get("https://hacker-news.firebaseio.com/v0/topstories.json") + story_ids = response.json() + + # Fetch story details + stories = [] + for story_id in story_ids[:num_stories]: + story_response = httpx.get( + f"https://hacker-news.firebaseio.com/v0/item/{story_id}.json" + ) + story = story_response.json() + story["username"] = story["by"] + stories.append(story) + return json.dumps(stories) + + def run(self, num_stories: int = 5) -> Iterator[RunResponse]: + # Set the tools for hn_agent here to avoid circular reference + self.hn_agent.tools = [self.get_top_hackernews_stories] + + logger.info(f"Getting top {num_stories} stories from HackerNews.") + top_stories: RunResponse = self.hn_agent.run(num_stories=num_stories) + if top_stories is None or not top_stories.content: + yield RunResponse( + run_id=self.run_id, content="Sorry, could not get the top stories." + ) + return + + logger.info("Reading each story and writing a report.") + yield from self.writer.run(top_stories.content, stream=True) + + +if __name__ == "__main__": + # Run workflow + report: Iterator[RunResponse] = HackerNewsReporter( + storage=YamlStorage(dir_path="tmp/workflow_sessions_yaml"), debug_mode=False + ).run(num_stories=5) + # Print the report + pprint_run_response(report, markdown=True, show_time=True) diff --git a/cookbook/workflows/blog_post_generator.py b/cookbook/workflows/blog_post_generator.py index 981789461f..79bed6eb9f 100644 --- a/cookbook/workflows/blog_post_generator.py +++ b/cookbook/workflows/blog_post_generator.py @@ -32,7 +32,7 @@ from agno.agent import Agent from agno.models.openai import OpenAIChat -from agno.storage.workflow.sqlite import SqliteWorkflowStorage +from agno.storage.sqlite import SqliteStorage from agno.tools.duckduckgo import DuckDuckGoTools from agno.tools.newspaper4k import Newspaper4kTools from agno.utils.log import logger @@ -412,7 +412,7 @@ def scrape_articles( # - Sets up SQLite storage for caching results generate_blog_post = BlogPostGenerator( session_id=f"generate-blog-post-on-{url_safe_topic}", - storage=SqliteWorkflowStorage( + storage=SqliteStorage( table_name="generate_blog_post_workflows", db_file="tmp/agno_workflows.db", ), diff --git a/cookbook/workflows/investment_report_generator.py b/cookbook/workflows/investment_report_generator.py index ffebca1cba..c062bf4b46 100644 --- a/cookbook/workflows/investment_report_generator.py +++ b/cookbook/workflows/investment_report_generator.py @@ -32,7 +32,7 @@ from typing import Iterator from agno.agent import Agent, RunResponse -from agno.storage.workflow.sqlite import SqliteWorkflowStorage +from agno.storage.sqlite import SqliteStorage from agno.tools.yfinance import YFinanceTools from agno.utils.log import logger from agno.utils.pprint import pprint_run_response @@ -216,7 +216,7 @@ def run(self, companies: str) -> Iterator[RunResponse]: # Initialize the investment analyst workflow investment_report_generator = InvestmentReportGenerator( session_id=f"investment-report-{url_safe_companies}", - storage=SqliteWorkflowStorage( + storage=SqliteStorage( table_name="investment_report_workflows", db_file="tmp/agno_workflows.db", ), diff --git a/cookbook/workflows/personalized_email_generator.py b/cookbook/workflows/personalized_email_generator.py index 09ede8a225..dd0bf6c23e 100644 --- a/cookbook/workflows/personalized_email_generator.py +++ b/cookbook/workflows/personalized_email_generator.py @@ -56,7 +56,7 @@ from agno.agent import Agent from agno.models.openai import OpenAIChat -from 
agno.storage.workflow.sqlite import SqliteWorkflowStorage +from agno.storage.sqlite import SqliteStorage from agno.tools.exa import ExaTools from agno.utils.log import logger from agno.utils.pprint import pprint_run_response @@ -439,7 +439,7 @@ def main(): # Create workflow with SQLite storage workflow = PersonalisedEmailGenerator( session_id="personalized-email-generator", - storage=SqliteWorkflowStorage( + storage=SqliteStorage( table_name="personalized_email_workflows", db_file="tmp/agno_workflows.db", ), diff --git a/cookbook/workflows/product_manager/product_manager.py b/cookbook/workflows/product_manager/product_manager.py index a2817f877b..ea9a165e37 100644 --- a/cookbook/workflows/product_manager/product_manager.py +++ b/cookbook/workflows/product_manager/product_manager.py @@ -4,7 +4,7 @@ from agno.agent.agent import Agent from agno.run.response import RunEvent, RunResponse -from agno.storage.workflow.postgres import PostgresWorkflowStorage +from agno.storage.postgres import PostgresStorage from agno.tools.linear import LinearTools from agno.tools.slack import SlackTools from agno.utils.log import logger @@ -158,7 +158,7 @@ def run( # Create the workflow product_manager = ProductManagerWorkflow( session_id="product-manager", - storage=PostgresWorkflowStorage( + storage=PostgresStorage( table_name="product_manager_workflows", db_url="postgresql+psycopg://ai:ai@localhost:5532/ai", ), diff --git a/cookbook/workflows/startup_idea_validator.py b/cookbook/workflows/startup_idea_validator.py index 980374a777..29f9b9dd85 100644 --- a/cookbook/workflows/startup_idea_validator.py +++ b/cookbook/workflows/startup_idea_validator.py @@ -50,7 +50,7 @@ from agno.agent import Agent from agno.models.openai import OpenAIChat -from agno.storage.workflow.sqlite import SqliteWorkflowStorage +from agno.storage.sqlite import SqliteStorage from agno.tools.googlesearch import GoogleSearchTools from agno.utils.log import logger from agno.utils.pprint import pprint_run_response @@ -264,7 +264,7 @@ def run(self, startup_idea: str) -> Iterator[RunResponse]: startup_idea_validator = StartupIdeaValidator( description="Startup Idea Validator", session_id=f"validate-startup-idea-{url_safe_idea}", - storage=SqliteWorkflowStorage( + storage=SqliteStorage( table_name="validate_startup_ideas_workflow", db_file="tmp/agno_workflows.db", ), diff --git a/cookbook/workflows/workflows_playground.py b/cookbook/workflows/workflows_playground.py index 8a7a559d86..77abd73a03 100644 --- a/cookbook/workflows/workflows_playground.py +++ b/cookbook/workflows/workflows_playground.py @@ -4,7 +4,7 @@ """ from agno.playground import Playground, serve_playground_app -from agno.storage.workflow.sqlite import SqliteWorkflowStorage +from agno.storage.sqlite import SqliteStorage # Import the workflows from blog_post_generator import BlogPostGenerator @@ -18,14 +18,14 @@ blog_post_generator = BlogPostGenerator( workflow_id="generate-blog-post", - storage=SqliteWorkflowStorage( + storage=SqliteStorage( table_name="generate_blog_post_workflows", db_file="tmp/agno_workflows.db", ), ) personalised_email_generator = PersonalisedEmailGenerator( workflow_id="personalized-email-generator", - storage=SqliteWorkflowStorage( + storage=SqliteStorage( table_name="personalized_email_workflows", db_file="tmp/agno_workflows.db", ), @@ -33,7 +33,7 @@ investment_report_generator = InvestmentReportGenerator( workflow_id="generate-investment-report", - storage=SqliteWorkflowStorage( + storage=SqliteStorage( table_name="investment_report_workflows", 
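+ # The playground workflows all share tmp/agno_workflows.db; each one is isolated by its own table_name.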
db_file="tmp/agno_workflows.db", ), @@ -41,7 +41,7 @@ startup_idea_validator = StartupIdeaValidator( workflow_id="validate-startup-idea", - storage=SqliteWorkflowStorage( + storage=SqliteStorage( table_name="validate_startup_ideas_workflow", db_file="tmp/agno_workflows.db", ), diff --git a/libs/agno/agno/agent/__init__.py b/libs/agno/agno/agent/__init__.py index e326efa9b4..3e38a4a324 100644 --- a/libs/agno/agno/agent/__init__.py +++ b/libs/agno/agno/agent/__init__.py @@ -3,10 +3,10 @@ AgentKnowledge, AgentMemory, AgentSession, - AgentStorage, Function, Message, RunEvent, RunResponse, + Storage, Toolkit, ) diff --git a/libs/agno/agno/agent/agent.py b/libs/agno/agno/agent/agent.py index 6794c75775..b29e14b1cb 100644 --- a/libs/agno/agno/agent/agent.py +++ b/libs/agno/agno/agent/agent.py @@ -35,8 +35,8 @@ from agno.reasoning.step import NextAction, ReasoningStep, ReasoningSteps from agno.run.messages import RunMessages from agno.run.response import RunEvent, RunResponse, RunResponseExtraData -from agno.storage.agent.base import AgentStorage -from agno.storage.agent.session import AgentSession +from agno.storage.base import Storage +from agno.storage.session.agent import AgentSession from agno.tools.function import Function from agno.tools.toolkit import Toolkit from agno.utils.log import logger, set_log_level_to_debug, set_log_level_to_info @@ -98,7 +98,7 @@ class Agent: references_format: Literal["json", "yaml"] = "json" # --- Agent Storage --- - storage: Optional[AgentStorage] = None + storage: Optional[Storage] = None # Extra data stored with this agent extra_data: Optional[Dict[str, Any]] = None @@ -250,7 +250,7 @@ def __init__( add_references: bool = False, retriever: Optional[Callable[..., Optional[List[Dict]]]] = None, references_format: Literal["json", "yaml"] = "json", - storage: Optional[AgentStorage] = None, + storage: Optional[Storage] = None, extra_data: Optional[Dict[str, Any]] = None, tools: Optional[List[Union[Toolkit, Callable, Function]]] = None, show_tool_calls: bool = False, @@ -425,6 +425,13 @@ def set_debug(self) -> None: else: set_log_level_to_info() + def set_storage_mode(self): + if self.storage is not None: + if self.storage.mode == "workflow": + logger.warning("You cannot use storage in both workflow and agent mode") + + self.storage.mode = "agent" + def set_monitoring(self) -> None: """Override monitoring and telemetry settings based on environment variables.""" @@ -438,6 +445,7 @@ def set_monitoring(self) -> None: self.telemetry = telemetry_env.lower() == "true" def initialize_agent(self) -> None: + self.set_storage_mode() self.set_debug() self.set_agent_id() self.set_session_id() @@ -1767,7 +1775,7 @@ def read_from_storage(self) -> Optional[AgentSession]: Optional[AgentSession]: The loaded AgentSession or None if not found. """ if self.storage is not None and self.session_id is not None: - self.agent_session = self.storage.read(session_id=self.session_id) + self.agent_session = cast(AgentSession, self.storage.read(session_id=self.session_id)) if self.agent_session is not None: self.load_agent_session(session=self.agent_session) self.load_user_memories() @@ -1780,7 +1788,7 @@ def write_to_storage(self) -> Optional[AgentSession]: Optional[AgentSession]: The saved AgentSession or None if not saved. 
""" if self.storage is not None: - self.agent_session = self.storage.upsert(session=self.get_agent_session()) + self.agent_session = cast(AgentSession, self.storage.upsert(session=self.get_agent_session())) return self.agent_session def add_introduction(self, introduction: str) -> None: diff --git a/libs/agno/agno/playground/async_router.py b/libs/agno/agno/playground/async_router.py index cc142d58a4..2621713ee1 100644 --- a/libs/agno/agno/playground/async_router.py +++ b/libs/agno/agno/playground/async_router.py @@ -27,8 +27,8 @@ WorkflowsGetResponse, ) from agno.run.response import RunEvent -from agno.storage.agent.session import AgentSession -from agno.storage.workflow.session import WorkflowSession +from agno.storage.session.agent import AgentSession +from agno.storage.session.workflow import WorkflowSession from agno.utils.log import logger from agno.workflow.workflow import Workflow @@ -302,7 +302,7 @@ async def get_all_agent_sessions(agent_id: str, user_id: Optional[str] = Query(N return JSONResponse(status_code=404, content="Agent does not have storage enabled.") agent_sessions: List[AgentSessionsResponse] = [] - all_agent_sessions: List[AgentSession] = agent.storage.get_all_sessions(user_id=user_id) + all_agent_sessions: List[AgentSession] = agent.storage.get_all_sessions(user_id=user_id) # type: ignore for session in all_agent_sessions: title = get_session_title(session) agent_sessions.append( @@ -325,7 +325,7 @@ async def get_agent_session(agent_id: str, session_id: str, user_id: Optional[st if agent.storage is None: return JSONResponse(status_code=404, content="Agent does not have storage enabled.") - agent_session: Optional[AgentSession] = agent.storage.read(session_id, user_id) + agent_session: Optional[AgentSession] = agent.storage.read(session_id, user_id) # type: ignore if agent_session is None: return JSONResponse(status_code=404, content="Session not found.") @@ -340,7 +340,7 @@ async def rename_agent_session(agent_id: str, session_id: str, body: AgentRename if agent.storage is None: return JSONResponse(status_code=404, content="Agent does not have storage enabled.") - all_agent_sessions: List[AgentSession] = agent.storage.get_all_sessions(user_id=body.user_id) + all_agent_sessions: List[AgentSession] = agent.storage.get_all_sessions(user_id=body.user_id) # type: ignore for session in all_agent_sessions: if session.session_id == session_id: agent.session_id = session_id @@ -358,7 +358,7 @@ async def delete_agent_session(agent_id: str, session_id: str, user_id: Optional if agent.storage is None: return JSONResponse(status_code=404, content="Agent does not have storage enabled.") - all_agent_sessions: List[AgentSession] = agent.storage.get_all_sessions(user_id=user_id) + all_agent_sessions: List[AgentSession] = agent.storage.get_all_sessions(user_id=user_id) # type: ignore for session in all_agent_sessions: if session.session_id == session_id: agent.delete_session(session_id) @@ -441,7 +441,7 @@ async def get_all_workflow_sessions(workflow_id: str, user_id: Optional[str] = Q try: all_workflow_sessions: List[WorkflowSession] = workflow.storage.get_all_sessions( user_id=user_id, workflow_id=workflow_id - ) + ) # type: ignore except Exception as e: raise HTTPException(status_code=500, detail=f"Error retrieving sessions: {str(e)}") @@ -471,7 +471,7 @@ async def get_workflow_session( # Retrieve the specific session try: - workflow_session: Optional[WorkflowSession] = workflow.storage.read(session_id, user_id) + workflow_session: Optional[WorkflowSession] = 
workflow.storage.read(session_id, user_id) # type: ignore except Exception as e: raise HTTPException(status_code=500, detail=f"Error retrieving session: {str(e)}") diff --git a/libs/agno/agno/playground/operator.py b/libs/agno/agno/playground/operator.py index aa60b8171d..9b8e09d9fa 100644 --- a/libs/agno/agno/playground/operator.py +++ b/libs/agno/agno/playground/operator.py @@ -1,8 +1,8 @@ from typing import List, Optional from agno.agent.agent import Agent, AgentRun, Function, Toolkit -from agno.storage.agent.session import AgentSession -from agno.storage.workflow.session import WorkflowSession +from agno.storage.session.agent import AgentSession +from agno.storage.session.workflow import WorkflowSession from agno.utils.log import logger from agno.workflow.workflow import Workflow diff --git a/libs/agno/agno/playground/sync_router.py b/libs/agno/agno/playground/sync_router.py index 69db41b8c4..907adb1c86 100644 --- a/libs/agno/agno/playground/sync_router.py +++ b/libs/agno/agno/playground/sync_router.py @@ -27,8 +27,8 @@ WorkflowsGetResponse, ) from agno.run.response import RunEvent -from agno.storage.agent.session import AgentSession -from agno.storage.workflow.session import WorkflowSession +from agno.storage.session.agent import AgentSession +from agno.storage.session.workflow import WorkflowSession from agno.utils.log import logger from agno.workflow.workflow import Workflow @@ -302,7 +302,7 @@ def get_user_agent_sessions(agent_id: str, user_id: Optional[str] = Query(None, return JSONResponse(status_code=404, content="Agent does not have storage enabled.") agent_sessions: List[AgentSessionsResponse] = [] - all_agent_sessions: List[AgentSession] = agent.storage.get_all_sessions(user_id=user_id) + all_agent_sessions: List[AgentSession] = agent.storage.get_all_sessions(user_id=user_id) # type: ignore for session in all_agent_sessions: title = get_session_title(session) agent_sessions.append( @@ -325,7 +325,7 @@ def get_user_agent_session(agent_id: str, session_id: str, user_id: Optional[str if agent.storage is None: return JSONResponse(status_code=404, content="Agent does not have storage enabled.") - agent_session: Optional[AgentSession] = agent.storage.read(session_id) + agent_session: Optional[AgentSession] = agent.storage.read(session_id) # type: ignore if agent_session is None: return JSONResponse(status_code=404, content="Session not found.") @@ -340,7 +340,7 @@ def rename_agent_session(agent_id: str, session_id: str, body: AgentRenameReques if agent.storage is None: return JSONResponse(status_code=404, content="Agent does not have storage enabled.") - all_agent_sessions: List[AgentSession] = agent.storage.get_all_sessions(user_id=body.user_id) + all_agent_sessions: List[AgentSession] = agent.storage.get_all_sessions(user_id=body.user_id) # type: ignore for session in all_agent_sessions: if session.session_id == session_id: agent.session_id = session_id @@ -358,7 +358,7 @@ def delete_agent_session(agent_id: str, session_id: str, user_id: Optional[str] if agent.storage is None: return JSONResponse(status_code=404, content="Agent does not have storage enabled.") - all_agent_sessions: List[AgentSession] = agent.storage.get_all_sessions(user_id=user_id) + all_agent_sessions: List[AgentSession] = agent.storage.get_all_sessions(user_id=user_id) # type: ignore for session in all_agent_sessions: if session.session_id == session_id: agent.delete_session(session_id) @@ -436,7 +436,7 @@ def get_all_workflow_sessions(workflow_id: str, user_id: Optional[str] = Query(N try: 
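+ # The shared Storage API returns base session objects; this endpoint knows they are WorkflowSessions, hence the type: ignore below.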
all_workflow_sessions: List[WorkflowSession] = workflow.storage.get_all_sessions( user_id=user_id, workflow_id=workflow_id - ) + ) # type: ignore except Exception as e: raise HTTPException(status_code=500, detail=f"Error retrieving sessions: {str(e)}") @@ -464,7 +464,7 @@ def get_workflow_session(workflow_id: str, session_id: str, user_id: Optional[st # Retrieve the specific session try: - workflow_session: Optional[WorkflowSession] = workflow.storage.read(session_id, user_id) + workflow_session: Optional[WorkflowSession] = workflow.storage.read(session_id, user_id) # type: ignore except Exception as e: raise HTTPException(status_code=500, detail=f"Error retrieving session: {str(e)}") diff --git a/libs/agno/agno/storage/agent/dynamodb.py b/libs/agno/agno/storage/agent/dynamodb.py index 79fe50883c..286c241f44 100644 --- a/libs/agno/agno/storage/agent/dynamodb.py +++ b/libs/agno/agno/storage/agent/dynamodb.py @@ -1,350 +1 @@ -import time -from dataclasses import asdict -from decimal import Decimal -from typing import Any, Dict, List, Optional - -from agno.storage.agent.base import AgentStorage -from agno.storage.agent.session import AgentSession -from agno.utils.log import logger - -try: - import boto3 - from boto3.dynamodb.conditions import Key - from botocore.exceptions import ClientError -except ImportError: - raise ImportError("`boto3` not installed. Please install using `pip install boto3`.") - - -class DynamoDbAgentStorage(AgentStorage): - def __init__( - self, - table_name: str, - region_name: Optional[str] = None, - aws_access_key_id: Optional[str] = None, - aws_secret_access_key: Optional[str] = None, - endpoint_url: Optional[str] = None, - create_table_if_not_exists: bool = True, - ): - """ - Initialize the DynamoDbAgentStorage. - - Args: - table_name (str): The name of the DynamoDB table. - region_name (Optional[str]): AWS region name. - aws_access_key_id (Optional[str]): AWS access key ID. - aws_secret_access_key (Optional[str]): AWS secret access key. - endpoint_url (Optional[str]): The complete URL to use for the constructed client. - create_table_if_not_exists (bool): Whether to create the table if it does not exist. - """ - self.table_name = table_name - self.region_name = region_name - self.endpoint_url = endpoint_url - self.aws_access_key_id = aws_access_key_id - self.aws_secret_access_key = aws_secret_access_key - self.create_table_if_not_exists = create_table_if_not_exists - - # Initialize DynamoDB resource - self.dynamodb = boto3.resource( - "dynamodb", - region_name=self.region_name, - aws_access_key_id=self.aws_access_key_id, - aws_secret_access_key=self.aws_secret_access_key, - endpoint_url=self.endpoint_url, - ) - - # Initialize table - self.table = self.dynamodb.Table(self.table_name) - - # Optionally create table if it does not exist - if self.create_table_if_not_exists: - self.create() - logger.debug(f"Initialized DynamoDbAgentStorage with table '{self.table_name}'") - - def create(self) -> None: - """ - Create the DynamoDB table if it does not exist. 
- """ - try: - # Check if table exists - self.dynamodb.meta.client.describe_table(TableName=self.table_name) - logger.debug(f"Table '{self.table_name}' already exists.") - except ClientError as e: - if e.response["Error"]["Code"] == "ResourceNotFoundException": - logger.debug(f"Creating table '{self.table_name}'.") - # Create the table - self.table = self.dynamodb.create_table( - TableName=self.table_name, - KeySchema=[{"AttributeName": "session_id", "KeyType": "HASH"}], - AttributeDefinitions=[ - {"AttributeName": "session_id", "AttributeType": "S"}, - {"AttributeName": "user_id", "AttributeType": "S"}, - {"AttributeName": "agent_id", "AttributeType": "S"}, - {"AttributeName": "created_at", "AttributeType": "N"}, - ], - GlobalSecondaryIndexes=[ - { - "IndexName": "user_id-index", - "KeySchema": [ - {"AttributeName": "user_id", "KeyType": "HASH"}, - {"AttributeName": "created_at", "KeyType": "RANGE"}, - ], - "Projection": {"ProjectionType": "ALL"}, - "ProvisionedThroughput": { - "ReadCapacityUnits": 5, - "WriteCapacityUnits": 5, - }, - }, - { - "IndexName": "agent_id-index", - "KeySchema": [ - {"AttributeName": "agent_id", "KeyType": "HASH"}, - {"AttributeName": "created_at", "KeyType": "RANGE"}, - ], - "Projection": {"ProjectionType": "ALL"}, - "ProvisionedThroughput": { - "ReadCapacityUnits": 5, - "WriteCapacityUnits": 5, - }, - }, - ], - ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, - ) - # Wait until the table exists. - self.table.wait_until_exists() - logger.debug(f"Table '{self.table_name}' created successfully.") - else: - logger.error(f"Unable to create table '{self.table_name}': {e.response['Error']['Message']}") - except Exception as e: - logger.error(f"Exception during table creation: {e}") - - def read(self, session_id: str, user_id: Optional[str] = None) -> Optional[AgentSession]: - """ - Read and return an AgentSession from the database. - - Args: - session_id (str): ID of the session to read. - user_id (Optional[str]): User ID to filter by. Defaults to None. - - Returns: - Optional[AgentSession]: AgentSession object if found, None otherwise. - """ - try: - key = {"session_id": session_id} - if user_id is not None: - key["user_id"] = user_id - - response = self.table.get_item(Key=key) - item = response.get("Item", None) - if item is not None: - # Convert Decimal to int or float - item = self._deserialize_item(item) - return AgentSession.from_dict(item) - except Exception as e: - logger.error(f"Error reading session_id '{session_id}' with user_id '{user_id}': {e}") - return None - - def get_all_session_ids(self, user_id: Optional[str] = None, agent_id: Optional[str] = None) -> List[str]: - """ - Retrieve all session IDs, optionally filtered by user_id and/or agent_id. - - Args: - user_id (Optional[str], optional): User ID to filter by. Defaults to None. - agent_id (Optional[str], optional): Agent ID to filter by. Defaults to None. - - Returns: - List[str]: List of session IDs matching the criteria. 
- """ - session_ids: List[str] = [] - try: - if user_id is not None: - # Query using user_id index - response = self.table.query( - IndexName="user_id-index", - KeyConditionExpression=Key("user_id").eq(user_id), - ProjectionExpression="session_id", - ) - items = response.get("Items", []) - session_ids.extend([item["session_id"] for item in items if "session_id" in item]) - elif agent_id is not None: - # Query using agent_id index - response = self.table.query( - IndexName="agent_id-index", - KeyConditionExpression=Key("agent_id").eq(agent_id), - ProjectionExpression="session_id", - ) - items = response.get("Items", []) - session_ids.extend([item["session_id"] for item in items if "session_id" in item]) - else: - # Scan the whole table - response = self.table.scan(ProjectionExpression="session_id") - items = response.get("Items", []) - session_ids.extend([item["session_id"] for item in items if "session_id" in item]) - except Exception as e: - logger.error(f"Error retrieving session IDs: {e}") - return session_ids - - def get_all_sessions(self, user_id: Optional[str] = None, agent_id: Optional[str] = None) -> List[AgentSession]: - """ - Retrieve all sessions, optionally filtered by user_id and/or agent_id. - - Args: - user_id (Optional[str], optional): User ID to filter by. Defaults to None. - agent_id (Optional[str], optional): Agent ID to filter by. Defaults to None. - - Returns: - List[AgentSession]: List of AgentSession objects matching the criteria. - """ - sessions: List[AgentSession] = [] - try: - if user_id is not None: - # Query using user_id index - response = self.table.query( - IndexName="user_id-index", - KeyConditionExpression=Key("user_id").eq(user_id), - ProjectionExpression="session_id, agent_id, user_id, memory, agent_data, session_data, extra_data, created_at, updated_at", - ) - items = response.get("Items", []) - for item in items: - item = self._deserialize_item(item) - _agent_session = AgentSession.from_dict(item) - if _agent_session is not None: - sessions.append(_agent_session) - elif agent_id is not None: - # Query using agent_id index - response = self.table.query( - IndexName="agent_id-index", - KeyConditionExpression=Key("agent_id").eq(agent_id), - ProjectionExpression="session_id, agent_id, user_id, memory, agent_data, session_data, extra_data, created_at, updated_at", - ) - items = response.get("Items", []) - for item in items: - item = self._deserialize_item(item) - _agent_session = AgentSession.from_dict(item) - if _agent_session is not None: - sessions.append(_agent_session) - else: - # Scan the whole table - response = self.table.scan( - ProjectionExpression="session_id, agent_id, user_id, memory, agent_data, session_data, extra_data, created_at, updated_at" - ) - items = response.get("Items", []) - for item in items: - item = self._deserialize_item(item) - _agent_session = AgentSession.from_dict(item) - if _agent_session is not None: - sessions.append(_agent_session) - except Exception as e: - logger.error(f"Error retrieving sessions: {e}") - return sessions - - def upsert(self, session: AgentSession) -> Optional[AgentSession]: - """ - Create or update an AgentSession in the database. - - Args: - session (AgentSession): The session data to upsert. - - Returns: - Optional[AgentSession]: The upserted AgentSession, or None if operation failed. 
- """ - try: - item = asdict(session) - - # Add timestamps - current_time = int(time.time()) - if "created_at" not in item or item["created_at"] is None: - item["created_at"] = current_time - item["updated_at"] = current_time - - # Convert data to DynamoDB compatible format - item = self._serialize_item(item) - - # Put item into DynamoDB - self.table.put_item(Item=item) - return self.read(session.session_id) - except Exception as e: - logger.error(f"Error upserting session: {e}") - return None - - def delete_session(self, session_id: Optional[str] = None): - """ - Delete a session from the database. - - Args: - session_id (Optional[str], optional): ID of the session to delete. Defaults to None. - """ - if session_id is None: - logger.warning("No session_id provided for deletion.") - return - try: - self.table.delete_item(Key={"session_id": session_id}) - logger.info(f"Successfully deleted session with session_id: {session_id}") - except Exception as e: - logger.error(f"Error deleting session: {e}") - - def drop(self) -> None: - """ - Drop the table from the database if it exists. - """ - try: - self.table.delete() - self.table.wait_until_not_exists() - logger.debug(f"Table '{self.table_name}' deleted successfully.") - except Exception as e: - logger.error(f"Error deleting table '{self.table_name}': {e}") - - def upgrade_schema(self) -> None: - """ - Upgrade the schema to the latest version. - This method is currently a placeholder and does not perform any actions. - """ - pass - - def _serialize_item(self, item: Dict[str, Any]) -> Dict[str, Any]: - """ - Serialize item to be compatible with DynamoDB. - - Args: - item (Dict[str, Any]): The item to serialize. - - Returns: - Dict[str, Any]: The serialized item. - """ - - def serialize_value(value): - if isinstance(value, float): - return Decimal(str(value)) - elif isinstance(value, dict): - return {k: serialize_value(v) for k, v in value.items()} - elif isinstance(value, list): - return [serialize_value(v) for v in value] - else: - return value - - return {k: serialize_value(v) for k, v in item.items() if v is not None} - - def _deserialize_item(self, item: Dict[str, Any]) -> Dict[str, Any]: - """ - Deserialize item from DynamoDB format. - - Args: - item (Dict[str, Any]): The item to deserialize. - - Returns: - Dict[str, Any]: The deserialized item. 
- """ - - def deserialize_value(value): - if isinstance(value, Decimal): - if value % 1 == 0: - return int(value) - else: - return float(value) - elif isinstance(value, dict): - return {k: deserialize_value(v) for k, v in value.items()} - elif isinstance(value, list): - return [deserialize_value(v) for v in value] - else: - return value - - return {k: deserialize_value(v) for k, v in item.items()} +from agno.storage.dynamodb import DynamoDbStorage as DynamoDbAgentStorage # noqa: F401 diff --git a/libs/agno/agno/storage/agent/json.py b/libs/agno/agno/storage/agent/json.py index f2fd8d8765..4c6547aeb4 100644 --- a/libs/agno/agno/storage/agent/json.py +++ b/libs/agno/agno/storage/agent/json.py @@ -1,92 +1 @@ -import json -import time -from dataclasses import asdict -from pathlib import Path -from typing import List, Optional, Union - -from agno.storage.agent.base import AgentStorage -from agno.storage.agent.session import AgentSession -from agno.utils.log import logger - - -class JsonAgentStorage(AgentStorage): - def __init__(self, dir_path: Union[str, Path]): - self.dir_path = Path(dir_path) - self.dir_path.mkdir(parents=True, exist_ok=True) - - def serialize(self, data: dict) -> str: - return json.dumps(data, ensure_ascii=False, indent=4) - - def deserialize(self, data: str) -> dict: - return json.loads(data) - - def create(self) -> None: - """Create the storage if it doesn't exist.""" - if not self.dir_path.exists(): - self.dir_path.mkdir(parents=True, exist_ok=True) - - def read(self, session_id: str, user_id: Optional[str] = None) -> Optional[AgentSession]: - """Read an AgentSession from storage.""" - try: - with open(self.dir_path / f"{session_id}.json", "r", encoding="utf-8") as f: - data = self.deserialize(f.read()) - if user_id and data["user_id"] != user_id: - return None - return AgentSession.from_dict(data) - except FileNotFoundError: - return None - - def get_all_session_ids(self, user_id: Optional[str] = None, agent_id: Optional[str] = None) -> List[str]: - """Get all session IDs, optionally filtered by user_id and/or agent_id.""" - session_ids = [] - for file in self.dir_path.glob("*.json"): - with open(file, "r", encoding="utf-8") as f: - data = self.deserialize(f.read()) - if (not user_id or data["user_id"] == user_id) and (not agent_id or data["agent_id"] == agent_id): - session_ids.append(data["session_id"]) - return session_ids - - def get_all_sessions(self, user_id: Optional[str] = None, agent_id: Optional[str] = None) -> List[AgentSession]: - """Get all sessions, optionally filtered by user_id and/or agent_id.""" - sessions = [] - for file in self.dir_path.glob("*.json"): - with open(file, "r", encoding="utf-8") as f: - data = self.deserialize(f.read()) - if (not user_id or data["user_id"] == user_id) and (not agent_id or data["agent_id"] == agent_id): - _agent_session = AgentSession.from_dict(data) - if _agent_session is not None: - sessions.append(_agent_session) - return sessions - - def upsert(self, session: AgentSession) -> Optional[AgentSession]: - """Insert or update an AgentSession in storage.""" - try: - data = asdict(session) - data["updated_at"] = int(time.time()) - if "created_at" not in data: - data["created_at"] = data["updated_at"] - - with open(self.dir_path / f"{session.session_id}.json", "w", encoding="utf-8") as f: - f.write(self.serialize(data)) - return session - except Exception as e: - logger.error(f"Error upserting session: {e}") - return None - - def delete_session(self, session_id: Optional[str] = None): - """Delete a session from storage.""" - 
if session_id is None: - return - try: - (self.dir_path / f"{session_id}.json").unlink(missing_ok=True) - except Exception as e: - logger.error(f"Error deleting session: {e}") - - def drop(self) -> None: - """Drop all sessions from storage.""" - for file in self.dir_path.glob("*.json"): - file.unlink() - - def upgrade_schema(self) -> None: - """Upgrade the schema of the storage.""" - pass +from agno.storage.json import JsonStorage as JsonAgentStorage # noqa: F401 diff --git a/libs/agno/agno/storage/agent/mongodb.py b/libs/agno/agno/storage/agent/mongodb.py index 2e1348cac6..5daddf5ec3 100644 --- a/libs/agno/agno/storage/agent/mongodb.py +++ b/libs/agno/agno/storage/agent/mongodb.py @@ -1,228 +1 @@ -from datetime import datetime, timezone -from typing import List, Optional -from uuid import UUID - -try: - from pymongo import MongoClient - from pymongo.collection import Collection - from pymongo.database import Database - from pymongo.errors import PyMongoError -except ImportError: - raise ImportError("`pymongo` not installed. Please install it with `pip install pymongo`") - -from agno.storage.agent.base import AgentStorage -from agno.storage.agent.session import AgentSession -from agno.utils.log import logger - - -class MongoDbAgentStorage(AgentStorage): - def __init__( - self, - collection_name: str, - db_url: Optional[str] = None, - db_name: str = "agno", - client: Optional[MongoClient] = None, - ): - """ - This class provides agent storage using MongoDB. - - Args: - collection_name: Name of the collection to store agent sessions - db_url: MongoDB connection URL - db_name: Name of the database - client: Optional existing MongoDB client - """ - self._client: Optional[MongoClient] = client - if self._client is None and db_url is not None: - self._client = MongoClient(db_url) - elif self._client is None: - self._client = MongoClient() - - if self._client is None: - raise ValueError("Must provide either db_url or client") - - self.collection_name: str = collection_name - self.db_name: str = db_name - self.db: Database = self._client[self.db_name] - self.collection: Collection = self.db[self.collection_name] - - def create(self) -> None: - """Create necessary indexes for the collection""" - try: - # Create indexes - self.collection.create_index("session_id", unique=True) - self.collection.create_index("user_id") - self.collection.create_index("agent_id") - self.collection.create_index("created_at") - except PyMongoError as e: - logger.error(f"Error creating indexes: {e}") - raise - - def read(self, session_id: str, user_id: Optional[str] = None) -> Optional[AgentSession]: - """Read an agent session from MongoDB - Args: - session_id: ID of the session to read - user_id: ID of the user to read - Returns: - AgentSession: The session if found, otherwise None - """ - try: - query = {"session_id": session_id} - if user_id: - query["user_id"] = user_id - - doc = self.collection.find_one(query) - if doc: - # Remove MongoDB _id before converting to AgentSession - doc.pop("_id", None) - return AgentSession.from_dict(doc) - return None - except PyMongoError as e: - logger.error(f"Error reading session: {e}") - return None - - def get_all_session_ids(self, user_id: Optional[str] = None, agent_id: Optional[str] = None) -> List[str]: - """Get all session IDs matching the criteria - Args: - user_id: ID of the user to read - agent_id: ID of the agent to read - Returns: - List[str]: List of session IDs - """ - try: - query = {} - if user_id is not None: - query["user_id"] = user_id - if agent_id is not None: - 
query["agent_id"] = agent_id - - cursor = self.collection.find(query, {"session_id": 1}).sort("created_at", -1) - - return [str(doc["session_id"]) for doc in cursor] - except PyMongoError as e: - logger.error(f"Error getting session IDs: {e}") - return [] - - def get_all_sessions(self, user_id: Optional[str] = None, agent_id: Optional[str] = None) -> List[AgentSession]: - """Get all sessions matching the criteria - Args: - user_id: ID of the user to read - agent_id: ID of the agent to read - Returns: - List[AgentSession]: List of sessions - """ - try: - query = {} - if user_id is not None: - query["user_id"] = user_id - if agent_id is not None: - query["agent_id"] = agent_id - - cursor = self.collection.find(query).sort("created_at", -1) - sessions = [] - for doc in cursor: - # Remove MongoDB _id before converting to AgentSession - doc.pop("_id", None) - _agent_session = AgentSession.from_dict(doc) - if _agent_session is not None: - sessions.append(_agent_session) - return sessions - except PyMongoError as e: - logger.error(f"Error getting sessions: {e}") - return [] - - def upsert(self, session: AgentSession, create_and_retry: bool = True) -> Optional[AgentSession]: - """Upsert an agent session - Args: - session: AgentSession to upsert - create_and_retry: Whether to create a new session if the session_id already exists - Returns: - AgentSession: The session if upserted, otherwise None - """ - try: - # Convert session to dict and add timestamps - session_dict = session.to_dict() - now = datetime.now(timezone.utc) - timestamp = int(now.timestamp()) - - # Handle UUID serialization - if isinstance(session.session_id, UUID): - session_dict["session_id"] = str(session.session_id) - - # Add version field for optimistic locking - if "_version" not in session_dict: - session_dict["_version"] = 1 - else: - session_dict["_version"] += 1 - - update_data = {**session_dict, "updated_at": timestamp} - - # For new documents, set created_at - query = {"session_id": session_dict["session_id"]} - - doc = self.collection.find_one(query) - if not doc: - update_data["created_at"] = timestamp - - result = self.collection.update_one(query, {"$set": update_data}, upsert=True) - - if result.acknowledged: - return self.read(session_id=session_dict["session_id"]) - return None - - except PyMongoError as e: - logger.error(f"Error upserting session: {e}") - return None - - def delete_session(self, session_id: Optional[str] = None) -> None: - """Delete an agent session - Args: - session_id: ID of the session to delete - Returns: - None - """ - if session_id is None: - logger.warning("No session_id provided for deletion") - return - - try: - result = self.collection.delete_one({"session_id": session_id}) - if result.deleted_count == 0: - logger.debug(f"No session found with session_id: {session_id}") - else: - logger.debug(f"Successfully deleted session with session_id: {session_id}") - except PyMongoError as e: - logger.error(f"Error deleting session: {e}") - - def drop(self) -> None: - """Drop the collection - Returns: - None - """ - try: - self.collection.drop() - except PyMongoError as e: - logger.error(f"Error dropping collection: {e}") - - def upgrade_schema(self) -> None: - """Placeholder for schema upgrades""" - pass - - def __deepcopy__(self, memo): - """Create a deep copy of the MongoDbAgentStorage instance""" - from copy import deepcopy - - # Create a new instance without calling __init__ - cls = self.__class__ - copied_obj = cls.__new__(cls) - memo[id(self)] = copied_obj - - # Deep copy attributes - for k, v 
in self.__dict__.items(): - if k in {"_client", "db", "collection"}: - # Reuse MongoDB connections without copying - setattr(copied_obj, k, v) - else: - setattr(copied_obj, k, deepcopy(v, memo)) - - return copied_obj +from agno.storage.mongodb import MongoDbStorage as MongoDbAgentStorage # noqa: F401 diff --git a/libs/agno/agno/storage/agent/postgres.py b/libs/agno/agno/storage/agent/postgres.py index 8b8ea257a6..9bed3fcc0f 100644 --- a/libs/agno/agno/storage/agent/postgres.py +++ b/libs/agno/agno/storage/agent/postgres.py @@ -1,367 +1 @@ -import time -from typing import List, Optional - -try: - from sqlalchemy.dialects import postgresql - from sqlalchemy.engine import Engine, create_engine - from sqlalchemy.inspection import inspect - from sqlalchemy.orm import scoped_session, sessionmaker - from sqlalchemy.schema import Column, Index, MetaData, Table - from sqlalchemy.sql.expression import select, text - from sqlalchemy.types import BigInteger, String -except ImportError: - raise ImportError("`sqlalchemy` not installed. Please install it using `pip install sqlalchemy`") - -from agno.storage.agent.base import AgentStorage -from agno.storage.agent.session import AgentSession -from agno.utils.log import logger - - -class PostgresAgentStorage(AgentStorage): - def __init__( - self, - table_name: str, - schema: Optional[str] = "ai", - db_url: Optional[str] = None, - db_engine: Optional[Engine] = None, - schema_version: int = 1, - auto_upgrade_schema: bool = False, - ): - """ - This class provides agent storage using a PostgreSQL table. - - The following order is used to determine the database connection: - 1. Use the db_engine if provided - 2. Use the db_url - 3. Raise an error if neither is provided - - Args: - table_name (str): Name of the table to store Agent sessions. - schema (Optional[str]): The schema to use for the table. Defaults to "ai". - db_url (Optional[str]): The database URL to connect to. - db_engine (Optional[Engine]): The SQLAlchemy database engine to use. - schema_version (int): Version of the schema. Defaults to 1. - auto_upgrade_schema (bool): Whether to automatically upgrade the schema. - - Raises: - ValueError: If neither db_url nor db_engine is provided. - """ - _engine: Optional[Engine] = db_engine - if _engine is None and db_url is not None: - _engine = create_engine(db_url) - - if _engine is None: - raise ValueError("Must provide either db_url or db_engine") - - # Database attributes - self.table_name: str = table_name - self.schema: Optional[str] = schema - self.db_url: Optional[str] = db_url - self.db_engine: Engine = _engine - self.metadata: MetaData = MetaData(schema=self.schema) - self.inspector = inspect(self.db_engine) - - # Table schema version - self.schema_version: int = schema_version - # Automatically upgrade schema if True - self.auto_upgrade_schema: bool = auto_upgrade_schema - - # Database session - self.Session: scoped_session = scoped_session(sessionmaker(bind=self.db_engine)) - # Database table for storage - self.table: Table = self.get_table() - logger.debug(f"Created PostgresAgentStorage: '{self.schema}.{self.table_name}'") - - def get_table_v1(self) -> Table: - """ - Define the table schema for version 1. - - Returns: - Table: SQLAlchemy Table object representing the schema. 
- """ - table = Table( - self.table_name, - self.metadata, - # Session UUID: Primary Key - Column("session_id", String, primary_key=True), - # ID of the agent that this session is associated with - Column("agent_id", String), - # ID of the user interacting with this agent - Column("user_id", String), - # Agent Memory - Column("memory", postgresql.JSONB), - # Agent Data - Column("agent_data", postgresql.JSONB), - # Session Data - Column("session_data", postgresql.JSONB), - # Extra Data stored with this agent - Column("extra_data", postgresql.JSONB), - # The Unix timestamp of when this session was created. - Column("created_at", BigInteger, server_default=text("(extract(epoch from now()))::bigint")), - # The Unix timestamp of when this session was last updated. - Column("updated_at", BigInteger, server_onupdate=text("(extract(epoch from now()))::bigint")), - extend_existing=True, - ) - - # Add indexes - Index(f"idx_{self.table_name}_session_id", table.c.session_id) - Index(f"idx_{self.table_name}_agent_id", table.c.agent_id) - Index(f"idx_{self.table_name}_user_id", table.c.user_id) - - return table - - def get_table(self) -> Table: - """ - Get the table schema based on the schema version. - - Returns: - Table: SQLAlchemy Table object for the current schema version. - - Raises: - ValueError: If an unsupported schema version is specified. - """ - if self.schema_version == 1: - return self.get_table_v1() - else: - raise ValueError(f"Unsupported schema version: {self.schema_version}") - - def table_exists(self) -> bool: - """ - Check if the table exists in the database. - - Returns: - bool: True if the table exists, False otherwise. - """ - logger.debug(f"Checking if table exists: {self.table.name}") - try: - return self.inspector.has_table(self.table.name, schema=self.schema) - except Exception as e: - logger.error(f"Error checking if table exists: {e}") - return False - - def create(self) -> None: - """ - Create the table if it does not exist. - """ - if not self.table_exists(): - try: - with self.Session() as sess, sess.begin(): - if self.schema is not None: - logger.debug(f"Creating schema: {self.schema}") - sess.execute(text(f"CREATE SCHEMA IF NOT EXISTS {self.schema};")) - logger.debug(f"Creating table: {self.table_name}") - self.table.create(self.db_engine, checkfirst=True) - except Exception as e: - logger.error(f"Could not create table: '{self.table.fullname}': {e}") - - def read(self, session_id: str, user_id: Optional[str] = None) -> Optional[AgentSession]: - """ - Read an AgentSession from the database. - - Args: - session_id (str): ID of the session to read. - user_id (Optional[str]): User ID to filter by. Defaults to None. - - Returns: - Optional[AgentSession]: AgentSession object if found, None otherwise. - """ - try: - with self.Session() as sess: - stmt = select(self.table).where(self.table.c.session_id == session_id) - if user_id: - stmt = stmt.where(self.table.c.user_id == user_id) - result = sess.execute(stmt).fetchone() - return AgentSession.from_dict(result._mapping) if result is not None else None - except Exception as e: - logger.debug(f"Exception reading from table: {e}") - logger.debug(f"Table does not exist: {self.table.name}") - logger.debug("Creating table for future transactions") - self.create() - return None - - def get_all_session_ids(self, user_id: Optional[str] = None, agent_id: Optional[str] = None) -> List[str]: - """ - Get all session IDs, optionally filtered by user_id and/or agent_id. - - Args: - user_id (Optional[str]): The ID of the user to filter by. 
- agent_id (Optional[str]): The ID of the agent to filter by. - - Returns: - List[str]: List of session IDs matching the criteria. - """ - try: - with self.Session() as sess, sess.begin(): - # get all session_ids - stmt = select(self.table.c.session_id) - if user_id is not None: - stmt = stmt.where(self.table.c.user_id == user_id) - if agent_id is not None: - stmt = stmt.where(self.table.c.agent_id == agent_id) - # order by created_at desc - stmt = stmt.order_by(self.table.c.created_at.desc()) - # execute query - rows = sess.execute(stmt).fetchall() - return [row[0] for row in rows] if rows is not None else [] - except Exception as e: - logger.debug(f"Exception reading from table: {e}") - logger.debug(f"Table does not exist: {self.table.name}") - logger.debug("Creating table for future transactions") - self.create() - return [] - - def get_all_sessions(self, user_id: Optional[str] = None, agent_id: Optional[str] = None) -> List[AgentSession]: - """ - Get all sessions, optionally filtered by user_id and/or agent_id. - - Args: - user_id (Optional[str]): The ID of the user to filter by. - agent_id (Optional[str]): The ID of the agent to filter by. - - Returns: - List[AgentSession]: List of AgentSession objects matching the criteria. - """ - try: - with self.Session() as sess, sess.begin(): - # get all sessions - stmt = select(self.table) - if user_id is not None: - stmt = stmt.where(self.table.c.user_id == user_id) - if agent_id is not None: - stmt = stmt.where(self.table.c.agent_id == agent_id) - # order by created_at desc - stmt = stmt.order_by(self.table.c.created_at.desc()) - # execute query - rows = sess.execute(stmt).fetchall() - return [AgentSession.from_dict(row._mapping) for row in rows] if rows is not None else [] # type: ignore - except Exception as e: - logger.debug(f"Exception reading from table: {e}") - logger.debug(f"Table does not exist: {self.table.name}") - logger.debug("Creating table for future transactions") - self.create() - return [] - - def upsert(self, session: AgentSession, create_and_retry: bool = True) -> Optional[AgentSession]: - """ - Insert or update an AgentSession in the database. - - Args: - session (AgentSession): The session data to upsert. - create_and_retry (bool): Retry upsert if table does not exist. - - Returns: - Optional[AgentSession]: The upserted AgentSession, or None if operation failed. 
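Both listing helpers shown here accept optional filters and return rows newest-first (ordered by created_at descending). A usage sketch with a hypothetical storage instance and IDs; note that the new Storage classes introduced later in this diff rename the agent_id filter to entity_id:

    session_ids = storage.get_all_session_ids(user_id="user_42")   # newest first
    sessions = storage.get_all_sessions(user_id="user_42")
    if session_ids:
        latest = storage.read(session_id=session_ids[0])           # most recent session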
- """ - try: - with self.Session() as sess, sess.begin(): - # Create an insert statement - stmt = postgresql.insert(self.table).values( - session_id=session.session_id, - agent_id=session.agent_id, - user_id=session.user_id, - memory=session.memory, - agent_data=session.agent_data, - session_data=session.session_data, - extra_data=session.extra_data, - ) - - # Define the upsert if the session_id already exists - # See: https://docs.sqlalchemy.org/en/20/dialects/postgresql.html#postgresql-insert-on-conflict - stmt = stmt.on_conflict_do_update( - index_elements=["session_id"], - set_=dict( - agent_id=session.agent_id, - user_id=session.user_id, - memory=session.memory, - agent_data=session.agent_data, - session_data=session.session_data, - extra_data=session.extra_data, - updated_at=int(time.time()), - ), # The updated value for each column - ) - - sess.execute(stmt) - except Exception as e: - logger.debug(f"Exception upserting into table: {e}") - if create_and_retry and not self.table_exists(): - logger.debug(f"Table does not exist: {self.table.name}") - logger.debug("Creating table and retrying upsert") - self.create() - return self.upsert(session, create_and_retry=False) - return None - return self.read(session_id=session.session_id) - - def delete_session(self, session_id: Optional[str] = None): - """ - Delete a session from the database. - - Args: - session_id (Optional[str], optional): ID of the session to delete. Defaults to None. - - Raises: - Exception: If an error occurs during deletion. - """ - if session_id is None: - logger.warning("No session_id provided for deletion.") - return - - try: - with self.Session() as sess, sess.begin(): - # Delete the session with the given session_id - delete_stmt = self.table.delete().where(self.table.c.session_id == session_id) - result = sess.execute(delete_stmt) - if result.rowcount == 0: - logger.debug(f"No session found with session_id: {session_id}") - else: - logger.debug(f"Successfully deleted session with session_id: {session_id}") - except Exception as e: - logger.error(f"Error deleting session: {e}") - - def drop(self) -> None: - """ - Drop the table from the database if it exists. - """ - if self.table_exists(): - logger.debug(f"Deleting table: {self.table_name}") - self.table.drop(self.db_engine) - - def upgrade_schema(self) -> None: - """ - Upgrade the schema to the latest version. - This method is currently a placeholder and does not perform any actions. - """ - pass - - def __deepcopy__(self, memo): - """ - Create a deep copy of the PostgresAgentStorage instance, handling unpickleable attributes. - - Args: - memo (dict): A dictionary of objects already copied during the current copying pass. - - Returns: - PostgresAgentStorage: A deep-copied instance of PostgresAgentStorage. 
- """ - from copy import deepcopy - - # Create a new instance without calling __init__ - cls = self.__class__ - copied_obj = cls.__new__(cls) - memo[id(self)] = copied_obj - - # Deep copy attributes - for k, v in self.__dict__.items(): - if k in {"metadata", "table", "inspector"}: - continue - # Reuse db_engine and Session without copying - elif k in {"db_engine", "Session"}: - setattr(copied_obj, k, v) - else: - setattr(copied_obj, k, deepcopy(v, memo)) - - # Recreate metadata and table for the copied instance - copied_obj.metadata = MetaData(schema=copied_obj.schema) - copied_obj.inspector = inspect(copied_obj.db_engine) - copied_obj.table = copied_obj.get_table() - - return copied_obj +from agno.storage.postgres import PostgresStorage as PostgresAgentStorage # noqa: F401 diff --git a/libs/agno/agno/storage/agent/singlestore.py b/libs/agno/agno/storage/agent/singlestore.py index fecc18edad..781600f414 100644 --- a/libs/agno/agno/storage/agent/singlestore.py +++ b/libs/agno/agno/storage/agent/singlestore.py @@ -1,303 +1 @@ -import json -from typing import Any, List, Optional - -try: - from sqlalchemy.dialects import mysql - from sqlalchemy.engine import Engine, create_engine - from sqlalchemy.engine.row import Row - from sqlalchemy.inspection import inspect - from sqlalchemy.orm import Session, sessionmaker - from sqlalchemy.schema import Column, MetaData, Table - from sqlalchemy.sql.expression import select, text -except ImportError: - raise ImportError("`sqlalchemy` not installed") - -from agno.storage.agent.base import AgentStorage -from agno.storage.agent.session import AgentSession -from agno.utils.log import logger - - -class SingleStoreAgentStorage(AgentStorage): - def __init__( - self, - table_name: str, - schema: Optional[str] = "ai", - db_url: Optional[str] = None, - db_engine: Optional[Engine] = None, - schema_version: int = 1, - auto_upgrade_schema: bool = False, - ): - """ - This class provides Agent storage using a singlestore table. - - The following order is used to determine the database connection: - 1. Use the db_engine if provided - 2. Use the db_url if provided - - Args: - table_name (str): The name of the table to store the agent data. - schema (Optional[str], optional): The schema of the table. Defaults to "ai". - db_url (Optional[str], optional): The database URL. Defaults to None. - db_engine (Optional[Engine], optional): The database engine. Defaults to None. - schema_version (int, optional): The schema version. Defaults to 1. - auto_upgrade_schema (bool, optional): Automatically upgrade the schema. Defaults to False. 
- """ - _engine: Optional[Engine] = db_engine - if _engine is None and db_url is not None: - _engine = create_engine(db_url, connect_args={"charset": "utf8mb4"}) - - if _engine is None: - raise ValueError("Must provide either db_url or db_engine") - - # Database attributes - self.table_name: str = table_name - self.schema: Optional[str] = schema - self.db_url: Optional[str] = db_url - self.db_engine: Engine = _engine - self.metadata: MetaData = MetaData(schema=self.schema) - - # Table schema version - self.schema_version: int = schema_version - # Automatically upgrade schema if True - self.auto_upgrade_schema: bool = auto_upgrade_schema - - # Database session - self.Session: sessionmaker[Session] = sessionmaker(bind=self.db_engine) - # Database table for storage - self.table: Table = self.get_table() - - def get_table_v1(self) -> Table: - return Table( - self.table_name, - self.metadata, - # Session UUID: Primary Key - Column("session_id", mysql.TEXT, primary_key=True), - # ID of the agent that this session is associated with - Column("agent_id", mysql.TEXT), - # ID of the user interacting with this agent - Column("user_id", mysql.TEXT), - # Agent memory - Column("memory", mysql.JSON), - # Agent Data - Column("agent_data", mysql.JSON), - # Session Data - Column("session_data", mysql.JSON), - # Extra Data stored with this agent - Column("extra_data", mysql.JSON), - # The Unix timestamp of when this session was created. - Column("created_at", mysql.BIGINT), - # The Unix timestamp of when this session was last updated. - Column("updated_at", mysql.BIGINT), - extend_existing=True, - ) - - def get_table(self) -> Table: - if self.schema_version == 1: - return self.get_table_v1() - else: - raise ValueError(f"Unsupported schema version: {self.schema_version}") - - def table_exists(self) -> bool: - logger.debug(f"Checking if table exists: {self.table.name}") - try: - return inspect(self.db_engine).has_table(self.table.name, schema=self.schema) - except Exception as e: - logger.error(e) - return False - - def create(self) -> None: - if not self.table_exists(): - logger.info(f"\nCreating table: {self.table_name}\n") - self.table.create(self.db_engine) - - def _read(self, session: Session, session_id: str, user_id: Optional[str] = None) -> Optional[Row[Any]]: - stmt = select(self.table).where(self.table.c.session_id == session_id) - if user_id is not None: - stmt = stmt.where(self.table.c.user_id == user_id) - try: - return session.execute(stmt).first() - except Exception as e: - logger.debug(f"Exception reading from table: {e}") - logger.debug(f"Table does not exist: {self.table.name}") - logger.debug(f"Creating table: {self.table_name}") - self.create() - return None - - def read(self, session_id: str, user_id: Optional[str] = None) -> Optional[AgentSession]: - with self.Session.begin() as sess: - existing_row: Optional[Row[Any]] = self._read(session=sess, session_id=session_id, user_id=user_id) - return AgentSession.from_dict(existing_row._mapping) if existing_row is not None else None # type: ignore - - def get_all_session_ids(self, user_id: Optional[str] = None, agent_id: Optional[str] = None) -> List[str]: - session_ids: List[str] = [] - try: - with self.Session.begin() as sess: - # get all session_ids for this user - stmt = select(self.table) - if user_id is not None: - stmt = stmt.where(self.table.c.user_id == user_id) - if agent_id is not None: - stmt = stmt.where(self.table.c.agent_id == agent_id) - # order by created_at desc - stmt = stmt.order_by(self.table.c.created_at.desc()) - # execute 
query - rows = sess.execute(stmt).fetchall() - for row in rows: - if row is not None and row.session_id is not None: - session_ids.append(row.session_id) - except Exception as e: - logger.error(f"An unexpected error occurred: {str(e)}") - return session_ids - - def get_all_sessions(self, user_id: Optional[str] = None, agent_id: Optional[str] = None) -> List[AgentSession]: - sessions: List[AgentSession] = [] - try: - with self.Session.begin() as sess: - # get all sessions for this user - stmt = select(self.table) - if user_id is not None: - stmt = stmt.where(self.table.c.user_id == user_id) - if agent_id is not None: - stmt = stmt.where(self.table.c.agent_id == agent_id) - # order by created_at desc - stmt = stmt.order_by(self.table.c.created_at.desc()) - # execute query - rows = sess.execute(stmt).fetchall() - for row in rows: - if row.session_id is not None: - _agent_session = AgentSession.from_dict(row._mapping) # type: ignore - if _agent_session is not None: - sessions.append(_agent_session) - except Exception: - logger.debug(f"Table does not exist: {self.table.name}") - return sessions - - def upsert(self, session: AgentSession) -> Optional[AgentSession]: - """ - Create a new session if it does not exist, otherwise update the existing session. - """ - - with self.Session.begin() as sess: - # Create an insert statement using MySQL's ON DUPLICATE KEY UPDATE syntax - upsert_sql = text( - f""" - INSERT INTO {self.schema}.{self.table_name} - (session_id, agent_id, user_id, memory, agent_data, session_data, extra_data, created_at, updated_at) - VALUES - (:session_id, :agent_id, :user_id, :memory, :agent_data, :session_data, :extra_data, UNIX_TIMESTAMP(), NULL) - ON DUPLICATE KEY UPDATE - agent_id = VALUES(agent_id), - user_id = VALUES(user_id), - memory = VALUES(memory), - agent_data = VALUES(agent_data), - session_data = VALUES(session_data), - extra_data = VALUES(extra_data), - updated_at = UNIX_TIMESTAMP(); - """ - ) - - try: - sess.execute( - upsert_sql, - { - "session_id": session.session_id, - "agent_id": session.agent_id, - "user_id": session.user_id, - "memory": json.dumps(session.memory, ensure_ascii=False) - if session.memory is not None - else None, - "agent_data": json.dumps(session.agent_data, ensure_ascii=False) - if session.agent_data is not None - else None, - "session_data": json.dumps(session.session_data, ensure_ascii=False) - if session.session_data is not None - else None, - "extra_data": json.dumps(session.extra_data, ensure_ascii=False) - if session.extra_data is not None - else None, - }, - ) - except Exception: - # Create table and try again - self.create() - sess.execute( - upsert_sql, - { - "session_id": session.session_id, - "agent_id": session.agent_id, - "user_id": session.user_id, - "memory": json.dumps(session.memory, ensure_ascii=False) - if session.memory is not None - else None, - "agent_data": json.dumps(session.agent_data, ensure_ascii=False) - if session.agent_data is not None - else None, - "session_data": json.dumps(session.session_data, ensure_ascii=False) - if session.session_data is not None - else None, - "extra_data": json.dumps(session.extra_data, ensure_ascii=False) - if session.extra_data is not None - else None, - }, - ) - return self.read(session_id=session.session_id) - - def delete_session(self, session_id: Optional[str] = None): - if session_id is None: - logger.warning("No session_id provided for deletion.") - return - - with self.Session() as sess, sess.begin(): - try: - # Delete the session with the given session_id - delete_stmt = 
self.table.delete().where(self.table.c.session_id == session_id) - result = sess.execute(delete_stmt) - - if result.rowcount == 0: - logger.warning(f"No session found with session_id: {session_id}") - else: - logger.info(f"Successfully deleted session with session_id: {session_id}") - except Exception as e: - logger.error(f"Error deleting session: {e}") - raise - - def drop(self) -> None: - if self.table_exists(): - logger.info(f"Deleting table: {self.table_name}") - self.table.drop(self.db_engine) - - def upgrade_schema(self) -> None: - pass - - def __deepcopy__(self, memo): - """ - Create a deep copy of the SingleStoreAgentStorage instance, handling unpickleable attributes. - - Args: - memo (dict): A dictionary of objects already copied during the current copying pass. - - Returns: - SingleStoreAgentStorage: A deep-copied instance of SingleStoreAgentStorage. - """ - from copy import deepcopy - - # Create a new instance without calling __init__ - cls = self.__class__ - copied_obj = cls.__new__(cls) - memo[id(self)] = copied_obj - - # Deep copy attributes - for k, v in self.__dict__.items(): - if k in {"metadata", "table"}: - continue - # Reuse db_engine and Session without copying - elif k in {"db_engine", "Session"}: - setattr(copied_obj, k, v) - else: - setattr(copied_obj, k, deepcopy(v, memo)) - - # Recreate metadata and table for the copied instance - copied_obj.metadata = MetaData(schema=self.schema) - copied_obj.table = copied_obj.get_table() - - return copied_obj +from agno.storage.singlestore import SingleStoreStorage as SingleStoreAgentStorage # noqa: F401 diff --git a/libs/agno/agno/storage/agent/sqlite.py b/libs/agno/agno/storage/agent/sqlite.py index 64f43ad169..d47dbce956 100644 --- a/libs/agno/agno/storage/agent/sqlite.py +++ b/libs/agno/agno/storage/agent/sqlite.py @@ -1,357 +1 @@ -import time -from pathlib import Path -from typing import List, Optional - -try: - from sqlalchemy.dialects import sqlite - from sqlalchemy.engine import Engine, create_engine - from sqlalchemy.inspection import inspect - from sqlalchemy.orm import Session, sessionmaker - from sqlalchemy.schema import Column, MetaData, Table - from sqlalchemy.sql.expression import select - from sqlalchemy.types import String -except ImportError: - raise ImportError("`sqlalchemy` not installed. Please install it using `pip install sqlalchemy`") - -from agno.storage.agent.base import AgentStorage -from agno.storage.agent.session import AgentSession -from agno.utils.log import logger - - -class SqliteAgentStorage(AgentStorage): - def __init__( - self, - table_name: str, - db_url: Optional[str] = None, - db_file: Optional[str] = None, - db_engine: Optional[Engine] = None, - schema_version: int = 1, - auto_upgrade_schema: bool = False, - ): - """ - This class provides agent storage using a sqlite database. - - The following order is used to determine the database connection: - 1. Use the db_engine if provided - 2. Use the db_url - 3. Use the db_file - 4. Create a new in-memory database - - Args: - table_name: The name of the table to store Agent sessions. - db_url: The database URL to connect to. - db_file: The database file to connect to. - db_engine: The SQLAlchemy database engine to use. 
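The SingleStore upsert a little further up uses raw MySQL-style ON DUPLICATE KEY UPDATE with JSON-encoded parameters rather than a dialect insert construct. A trimmed-down sketch of the same idiom, assuming a reachable SingleStore/MySQL database (connection string, schema, and table are illustrative):

    import json

    from sqlalchemy import create_engine, text

    engine = create_engine("mysql+pymysql://user:pass@localhost:3306/ai")  # hypothetical
    upsert_sql = text(
        """
        INSERT INTO ai.agent_sessions (session_id, memory, created_at)
        VALUES (:session_id, :memory, UNIX_TIMESTAMP())
        ON DUPLICATE KEY UPDATE memory = VALUES(memory), updated_at = UNIX_TIMESTAMP();
        """
    )
    with engine.begin() as conn:
        conn.execute(upsert_sql, {"session_id": "abc-123", "memory": json.dumps({"runs": []})})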
- """ - _engine: Optional[Engine] = db_engine - if _engine is None and db_url is not None: - _engine = create_engine(db_url) - elif _engine is None and db_file is not None: - # Use the db_file to create the engine - db_path = Path(db_file).resolve() - # Ensure the directory exists - db_path.parent.mkdir(parents=True, exist_ok=True) - _engine = create_engine(f"sqlite:///{db_path}") - else: - _engine = create_engine("sqlite://") - - if _engine is None: - raise ValueError("Must provide either db_url, db_file or db_engine") - - # Database attributes - self.table_name: str = table_name - self.db_url: Optional[str] = db_url - self.db_engine: Engine = _engine - self.metadata: MetaData = MetaData() - self.inspector = inspect(self.db_engine) - - # Table schema version - self.schema_version: int = schema_version - # Automatically upgrade schema if True - self.auto_upgrade_schema: bool = auto_upgrade_schema - - # Database session - self.Session: sessionmaker[Session] = sessionmaker(bind=self.db_engine) - # Database table for storage - self.table: Table = self.get_table() - - def get_table_v1(self) -> Table: - """ - Define the table schema for version 1. - - Returns: - Table: SQLAlchemy Table object representing the schema. - """ - return Table( - self.table_name, - self.metadata, - # Session UUID: Primary Key - Column("session_id", String, primary_key=True), - # ID of the agent that this session is associated with - Column("agent_id", String), - # ID of the user interacting with this agent - Column("user_id", String), - # Agent Memory - Column("memory", sqlite.JSON), - # Agent Data - Column("agent_data", sqlite.JSON), - # Session Data - Column("session_data", sqlite.JSON), - # Extra Data stored with this agent - Column("extra_data", sqlite.JSON), - # The Unix timestamp of when this session was created. - Column("created_at", sqlite.INTEGER, default=lambda: int(time.time())), - # The Unix timestamp of when this session was last updated. - Column("updated_at", sqlite.INTEGER, onupdate=lambda: int(time.time())), - extend_existing=True, - sqlite_autoincrement=True, - ) - - def get_table(self) -> Table: - """ - Get the table schema based on the schema version. - - Returns: - Table: SQLAlchemy Table object for the current schema version. - - Raises: - ValueError: If an unsupported schema version is specified. - """ - if self.schema_version == 1: - return self.get_table_v1() - else: - raise ValueError(f"Unsupported schema version: {self.schema_version}") - - def table_exists(self) -> bool: - """ - Check if the table exists in the database. - - Returns: - bool: True if the table exists, False otherwise. - """ - logger.debug(f"Checking if table exists: {self.table.name}") - try: - return self.inspector.has_table(self.table.name) - except Exception as e: - logger.error(f"Error checking if table exists: {e}") - return False - - def create(self) -> None: - """ - Create the table if it doesn't exist. - """ - if not self.table_exists(): - logger.debug(f"Creating table: {self.table.name}") - self.table.create(self.db_engine, checkfirst=True) - - def read(self, session_id: str, user_id: Optional[str] = None) -> Optional[AgentSession]: - """ - Read an AgentSession from the database. - - Args: - session_id (str): ID of the session to read. - user_id (Optional[str]): User ID to filter by. Defaults to None. - - Returns: - Optional[AgentSession]: AgentSession object if found, None otherwise. 
- """ - try: - with self.Session() as sess: - stmt = select(self.table).where(self.table.c.session_id == session_id) - if user_id: - stmt = stmt.where(self.table.c.user_id == user_id) - result = sess.execute(stmt).fetchone() - return AgentSession.from_dict(result._mapping) if result is not None else None # type: ignore - except Exception as e: - logger.debug(f"Exception reading from table: {e}") - logger.debug(f"Table does not exist: {self.table.name}") - logger.debug("Creating table for future transactions") - self.create() - return None - - def get_all_session_ids(self, user_id: Optional[str] = None, agent_id: Optional[str] = None) -> List[str]: - """ - Get all session IDs, optionally filtered by user_id and/or agent_id. - - Args: - user_id (Optional[str]): The ID of the user to filter by. - agent_id (Optional[str]): The ID of the agent to filter by. - - Returns: - List[str]: List of session IDs matching the criteria. - """ - try: - with self.Session() as sess, sess.begin(): - # get all session_ids - stmt = select(self.table.c.session_id) - if user_id is not None: - stmt = stmt.where(self.table.c.user_id == user_id) - if agent_id is not None: - stmt = stmt.where(self.table.c.agent_id == agent_id) - # order by created_at desc - stmt = stmt.order_by(self.table.c.created_at.desc()) - # execute query - rows = sess.execute(stmt).fetchall() - return [row[0] for row in rows] if rows is not None else [] - except Exception as e: - logger.debug(f"Exception reading from table: {e}") - logger.debug(f"Table does not exist: {self.table.name}") - logger.debug("Creating table for future transactions") - self.create() - return [] - - def get_all_sessions(self, user_id: Optional[str] = None, agent_id: Optional[str] = None) -> List[AgentSession]: - """ - Get all sessions, optionally filtered by user_id and/or agent_id. - - Args: - user_id (Optional[str]): The ID of the user to filter by. - agent_id (Optional[str]): The ID of the agent to filter by. - - Returns: - List[AgentSession]: List of AgentSession objects matching the criteria. - """ - try: - with self.Session() as sess, sess.begin(): - # get all sessions - stmt = select(self.table) - if user_id is not None: - stmt = stmt.where(self.table.c.user_id == user_id) - if agent_id is not None: - stmt = stmt.where(self.table.c.agent_id == agent_id) - # order by created_at desc - stmt = stmt.order_by(self.table.c.created_at.desc()) - # execute query - rows = sess.execute(stmt).fetchall() - return [AgentSession.from_dict(row._mapping) for row in rows] if rows is not None else [] # type: ignore - except Exception as e: - logger.debug(f"Exception reading from table: {e}") - logger.debug(f"Table does not exist: {self.table.name}") - logger.debug("Creating table for future transactions") - self.create() - return [] - - def upsert(self, session: AgentSession, create_and_retry: bool = True) -> Optional[AgentSession]: - """ - Insert or update an AgentSession in the database. - - Args: - session (AgentSession): The session data to upsert. - create_and_retry (bool): Retry upsert if table does not exist. - - Returns: - Optional[AgentSession]: The upserted AgentSession, or None if operation failed. 
- """ - try: - with self.Session() as sess, sess.begin(): - # Create an insert statement - stmt = sqlite.insert(self.table).values( - session_id=session.session_id, - agent_id=session.agent_id, - user_id=session.user_id, - memory=session.memory, - agent_data=session.agent_data, - session_data=session.session_data, - extra_data=session.extra_data, - ) - - # Define the upsert if the session_id already exists - # See: https://docs.sqlalchemy.org/en/20/dialects/sqlite.html#insert-on-conflict-upsert - stmt = stmt.on_conflict_do_update( - index_elements=["session_id"], - set_=dict( - agent_id=session.agent_id, - user_id=session.user_id, - memory=session.memory, - agent_data=session.agent_data, - session_data=session.session_data, - extra_data=session.extra_data, - updated_at=int(time.time()), - ), # The updated value for each column - ) - - sess.execute(stmt) - except Exception as e: - logger.debug(f"Exception upserting into table: {e}") - if create_and_retry and not self.table_exists(): - logger.debug(f"Table does not exist: {self.table.name}") - logger.debug("Creating table and retrying upsert") - self.create() - return self.upsert(session, create_and_retry=False) - return None - return self.read(session_id=session.session_id) - - def delete_session(self, session_id: Optional[str] = None): - """ - Delete a workflow session from the database. - - Args: - session_id (Optional[str]): The ID of the session to delete. - - Raises: - ValueError: If session_id is not provided. - """ - if session_id is None: - logger.warning("No session_id provided for deletion.") - return - - try: - with self.Session() as sess, sess.begin(): - # Delete the session with the given session_id - delete_stmt = self.table.delete().where(self.table.c.session_id == session_id) - result = sess.execute(delete_stmt) - if result.rowcount == 0: - logger.debug(f"No session found with session_id: {session_id}") - else: - logger.debug(f"Successfully deleted session with session_id: {session_id}") - except Exception as e: - logger.error(f"Error deleting session: {e}") - - def drop(self) -> None: - """ - Drop the table from the database if it exists. - """ - if self.table_exists(): - logger.debug(f"Deleting table: {self.table_name}") - self.table.drop(self.db_engine) - - def upgrade_schema(self) -> None: - """ - Upgrade the schema of the workflow storage table. - This method is currently a placeholder and does not perform any actions. - """ - pass - - def __deepcopy__(self, memo): - """ - Create a deep copy of the SqliteAgentStorage instance, handling unpickleable attributes. - - Args: - memo (dict): A dictionary of objects already copied during the current copying pass. - - Returns: - SqliteAgentStorage: A deep-copied instance of SqliteAgentStorage. 
- """ - from copy import deepcopy - - # Create a new instance without calling __init__ - cls = self.__class__ - copied_obj = cls.__new__(cls) - memo[id(self)] = copied_obj - - # Deep copy attributes - for k, v in self.__dict__.items(): - if k in {"metadata", "table", "inspector"}: - continue - # Reuse db_engine and Session without copying - elif k in {"db_engine", "Session"}: - setattr(copied_obj, k, v) - else: - setattr(copied_obj, k, deepcopy(v, memo)) - - # Recreate metadata and table for the copied instance - copied_obj.metadata = MetaData() - copied_obj.inspector = inspect(copied_obj.db_engine) - copied_obj.table = copied_obj.get_table() - - return copied_obj +from agno.storage.sqlite import SqliteStorage as SqliteAgentStorage # noqa: F401 diff --git a/libs/agno/agno/storage/agent/yaml.py b/libs/agno/agno/storage/agent/yaml.py index 2af24e5ba6..b2673d6750 100644 --- a/libs/agno/agno/storage/agent/yaml.py +++ b/libs/agno/agno/storage/agent/yaml.py @@ -1,93 +1 @@ -import time -from dataclasses import asdict -from pathlib import Path -from typing import List, Optional, Union - -import yaml - -from agno.storage.agent.base import AgentStorage -from agno.storage.agent.session import AgentSession -from agno.utils.log import logger - - -class YamlAgentStorage(AgentStorage): - def __init__(self, dir_path: Union[str, Path]): - self.dir_path = Path(dir_path) - self.dir_path.mkdir(parents=True, exist_ok=True) - - def serialize(self, data: dict) -> str: - return yaml.dump(data, default_flow_style=False) - - def deserialize(self, data: str) -> dict: - return yaml.safe_load(data) - - def create(self) -> None: - """Create the storage if it doesn't exist.""" - if not self.dir_path.exists(): - self.dir_path.mkdir(parents=True, exist_ok=True) - - def read(self, session_id: str, user_id: Optional[str] = None) -> Optional[AgentSession]: - """Read an AgentSession from storage.""" - try: - with open(self.dir_path / f"{session_id}.yaml", "r", encoding="utf-8") as f: - data = self.deserialize(f.read()) - if user_id and data["user_id"] != user_id: - return None - return AgentSession.from_dict(data) - except FileNotFoundError: - return None - - def get_all_session_ids(self, user_id: Optional[str] = None, agent_id: Optional[str] = None) -> List[str]: - """Get all session IDs, optionally filtered by user_id and/or agent_id.""" - session_ids = [] - for file in self.dir_path.glob("*.yaml"): - with open(file, "r", encoding="utf-8") as f: - data = self.deserialize(f.read()) - if (not user_id or data["user_id"] == user_id) and (not agent_id or data["agent_id"] == agent_id): - session_ids.append(data["session_id"]) - return session_ids - - def get_all_sessions(self, user_id: Optional[str] = None, agent_id: Optional[str] = None) -> List[AgentSession]: - """Get all sessions, optionally filtered by user_id and/or agent_id.""" - sessions = [] - for file in self.dir_path.glob("*.yaml"): - with open(file, "r", encoding="utf-8") as f: - data = self.deserialize(f.read()) - if (not user_id or data["user_id"] == user_id) and (not agent_id or data["agent_id"] == agent_id): - _agent_session = AgentSession.from_dict(data) - if _agent_session is not None: - sessions.append(_agent_session) - return sessions - - def upsert(self, session: AgentSession) -> Optional[AgentSession]: - """Insert or update an AgentSession in storage.""" - try: - data = asdict(session) - data["updated_at"] = int(time.time()) - if "created_at" not in data: - data["created_at"] = data["updated_at"] - - with open(self.dir_path / f"{session.session_id}.yaml", 
"w", encoding="utf-8") as f: - f.write(self.serialize(data)) - return session - except Exception as e: - logger.error(f"Error upserting session: {e}") - return None - - def delete_session(self, session_id: Optional[str] = None): - """Delete a session from storage.""" - if session_id is None: - return - try: - (self.dir_path / f"{session_id}.yaml").unlink(missing_ok=True) - except Exception as e: - logger.error(f"Error deleting session: {e}") - - def drop(self) -> None: - """Drop all sessions from storage.""" - for file in self.dir_path.glob("*.yaml"): - file.unlink() - - def upgrade_schema(self) -> None: - """Upgrade the schema of the storage.""" - pass +from agno.storage.yaml import YamlStorage as YamlAgentStorage # noqa: F401 diff --git a/libs/agno/agno/storage/agent/base.py b/libs/agno/agno/storage/base.py similarity index 54% rename from libs/agno/agno/storage/agent/base.py rename to libs/agno/agno/storage/base.py index 03fd2aa56c..8567b0a22e 100644 --- a/libs/agno/agno/storage/agent/base.py +++ b/libs/agno/agno/storage/base.py @@ -1,16 +1,29 @@ from abc import ABC, abstractmethod -from typing import List, Optional +from typing import List, Literal, Optional -from agno.storage.agent.session import AgentSession +from agno.storage.session import Session -class AgentStorage(ABC): +class Storage(ABC): + def __init__(self, mode: Optional[Literal["agent", "workflow"]] = "agent"): + self._mode: Literal["agent", "workflow"] = "agent" if mode is None else mode + + @property + def mode(self) -> Literal["agent", "workflow"]: + """Get the mode of the storage.""" + return self._mode + + @mode.setter + def mode(self, value: Optional[Literal["agent", "workflow"]]) -> None: + """Set the mode of the storage.""" + self._mode = "agent" if value is None else value + @abstractmethod def create(self) -> None: raise NotImplementedError @abstractmethod - def read(self, session_id: str, user_id: Optional[str] = None) -> Optional[AgentSession]: + def read(self, session_id: str, user_id: Optional[str] = None) -> Optional[Session]: raise NotImplementedError @abstractmethod @@ -18,11 +31,11 @@ def get_all_session_ids(self, user_id: Optional[str] = None, agent_id: Optional[ raise NotImplementedError @abstractmethod - def get_all_sessions(self, user_id: Optional[str] = None, agent_id: Optional[str] = None) -> List[AgentSession]: + def get_all_sessions(self, user_id: Optional[str] = None, agent_id: Optional[str] = None) -> List[Session]: raise NotImplementedError @abstractmethod - def upsert(self, session: AgentSession) -> Optional[AgentSession]: + def upsert(self, session: Session) -> Optional[Session]: raise NotImplementedError @abstractmethod diff --git a/libs/agno/agno/storage/dynamodb.py b/libs/agno/agno/storage/dynamodb.py new file mode 100644 index 0000000000..3f93c0e587 --- /dev/null +++ b/libs/agno/agno/storage/dynamodb.py @@ -0,0 +1,436 @@ +import time +from dataclasses import asdict +from decimal import Decimal +from typing import Any, Dict, List, Literal, Optional + +from agno.storage.base import Storage +from agno.storage.session import Session +from agno.storage.session.agent import AgentSession +from agno.storage.session.workflow import WorkflowSession +from agno.utils.log import logger + +try: + import boto3 + from boto3.dynamodb.conditions import Key + from botocore.exceptions import ClientError +except ImportError: + raise ImportError("`boto3` not installed. 
Please install using `pip install boto3`.") + + +class DynamoDbStorage(Storage): + def __init__( + self, + table_name: str, + region_name: Optional[str] = None, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + endpoint_url: Optional[str] = None, + create_table_if_not_exists: bool = True, + mode: Optional[Literal["agent", "workflow"]] = "agent", + ): + """ + Initialize the DynamoDbStorage. + + Args: + table_name (str): The name of the DynamoDB table. + region_name (Optional[str]): AWS region name. + aws_access_key_id (Optional[str]): AWS access key ID. + aws_secret_access_key (Optional[str]): AWS secret access key. + endpoint_url (Optional[str]): The complete URL to use for the constructed client. + create_table_if_not_exists (bool): Whether to create the table if it does not exist. + mode (Optional[Literal["agent", "workflow"]]): The mode of the storage. + """ + super().__init__(mode) + self.table_name = table_name + self.region_name = region_name + self.endpoint_url = endpoint_url + self.aws_access_key_id = aws_access_key_id + self.aws_secret_access_key = aws_secret_access_key + self.create_table_if_not_exists = create_table_if_not_exists + + # Initialize DynamoDB resource + self.dynamodb = boto3.resource( + "dynamodb", + region_name=self.region_name, + aws_access_key_id=self.aws_access_key_id, + aws_secret_access_key=self.aws_secret_access_key, + endpoint_url=self.endpoint_url, + ) + + # Initialize table + self.table = self.dynamodb.Table(self.table_name) + + # Optionally create table if it does not exist + if self.create_table_if_not_exists: + self.create() + logger.debug(f"Initialized DynamoDbStorage with table '{self.table_name}'") + + @property + def mode(self) -> Literal["agent", "workflow"]: + """Get the mode of the storage.""" + return super().mode + + @mode.setter + def mode(self, value: Optional[Literal["agent", "workflow"]]) -> None: + """Set the mode and refresh the table if mode changes.""" + super(DynamoDbStorage, type(self)).mode.fset(self, value) # type: ignore + if value is not None: + if self.create_table_if_not_exists: + self.create() + + def create(self) -> None: + """ + Create the DynamoDB table if it does not exist. 
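The mode setter above overrides a property inherited from the new Storage base class and chains to the parent setter through its fset slot. A self-contained sketch of that idiom, since it is easy to get wrong:

    from typing import Literal, Optional


    class Base:
        def __init__(self) -> None:
            self._mode: Literal["agent", "workflow"] = "agent"

        @property
        def mode(self) -> Literal["agent", "workflow"]:
            return self._mode

        @mode.setter
        def mode(self, value: Optional[Literal["agent", "workflow"]]) -> None:
            self._mode = "agent" if value is None else value


    class Child(Base):
        @property
        def mode(self) -> Literal["agent", "workflow"]:
            return super().mode

        @mode.setter
        def mode(self, value: Optional[Literal["agent", "workflow"]]) -> None:
            # super(Child, type(self)).mode is the Base property object;
            # calling .fset runs the parent setter before the extra work below.
            super(Child, type(self)).mode.fset(self, value)  # type: ignore
            print(f"mode is now {self._mode}; refresh table schema here")


    Child().mode = "workflow"  # prints: mode is now workflow; refresh table schema here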
+ """ + try: + # Check if table exists + self.dynamodb.meta.client.describe_table(TableName=self.table_name) + logger.debug(f"Table '{self.table_name}' already exists.") + except ClientError as e: + if e.response["Error"]["Code"] == "ResourceNotFoundException": + logger.debug(f"Creating table '{self.table_name}'.") + + if self.mode == "agent": + attribute_definitions = [ + {"AttributeName": "session_id", "AttributeType": "S"}, + {"AttributeName": "user_id", "AttributeType": "S"}, + {"AttributeName": "agent_id", "AttributeType": "S"}, + {"AttributeName": "created_at", "AttributeType": "N"}, + ] + else: + attribute_definitions = [ + {"AttributeName": "session_id", "AttributeType": "S"}, + {"AttributeName": "user_id", "AttributeType": "S"}, + {"AttributeName": "workflow_id", "AttributeType": "S"}, + {"AttributeName": "created_at", "AttributeType": "N"}, + ] + + secondary_indexes = [ + { + "IndexName": "user_id-index", + "KeySchema": [ + {"AttributeName": "user_id", "KeyType": "HASH"}, + {"AttributeName": "created_at", "KeyType": "RANGE"}, + ], + "Projection": {"ProjectionType": "ALL"}, + "ProvisionedThroughput": { + "ReadCapacityUnits": 5, + "WriteCapacityUnits": 5, + }, + } + ] + if self.mode == "agent": + secondary_indexes.append( + { + "IndexName": "agent_id-index", + "KeySchema": [ + {"AttributeName": "agent_id", "KeyType": "HASH"}, + {"AttributeName": "created_at", "KeyType": "RANGE"}, + ], + "Projection": {"ProjectionType": "ALL"}, + "ProvisionedThroughput": { + "ReadCapacityUnits": 5, + "WriteCapacityUnits": 5, + }, + } + ) + else: + secondary_indexes.append( + { + "IndexName": "workflow_id-index", + "KeySchema": [ + {"AttributeName": "workflow_id", "KeyType": "HASH"}, + {"AttributeName": "created_at", "KeyType": "RANGE"}, + ], + "Projection": {"ProjectionType": "ALL"}, + "ProvisionedThroughput": { + "ReadCapacityUnits": 5, + "WriteCapacityUnits": 5, + }, + } + ) + + # Create the table + self.table = self.dynamodb.create_table( + TableName=self.table_name, + KeySchema=[{"AttributeName": "session_id", "KeyType": "HASH"}], + AttributeDefinitions=attribute_definitions, + GlobalSecondaryIndexes=secondary_indexes, + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + # Wait until the table exists. + self.table.wait_until_exists() + logger.debug(f"Table '{self.table_name}' created successfully.") + else: + logger.error(f"Unable to create table '{self.table_name}': {e.response['Error']['Message']}") + except Exception as e: + logger.error(f"Exception during table creation: {e}") + + def read(self, session_id: str, user_id: Optional[str] = None) -> Optional[Session]: + """ + Read and return a Session from the database. + + Args: + session_id (str): ID of the session to read. + user_id (Optional[str]): User ID to filter by. Defaults to None. + + Returns: + Optional[Session]: Session object if found, None otherwise. 
+ """ + try: + key = {"session_id": session_id} + if user_id is not None: + key["user_id"] = user_id + + response = self.table.get_item(Key=key) + item = response.get("Item", None) + if item is not None: + # Convert Decimal to int or float + item = self._deserialize_item(item) + if self.mode == "agent": + return AgentSession.from_dict(item) + elif self.mode == "workflow": + return WorkflowSession.from_dict(item) + except Exception as e: + logger.error(f"Error reading session_id '{session_id}' with user_id '{user_id}': {e}") + return None + + def get_all_session_ids(self, user_id: Optional[str] = None, entity_id: Optional[str] = None) -> List[str]: + """ + Retrieve all session IDs, optionally filtered by user_id and/or entity_id. + + Args: + user_id (Optional[str], optional): User ID to filter by. Defaults to None. + entity_id (Optional[str], optional): Entity ID to filter by. Defaults to None. + + Returns: + List[str]: List of session IDs matching the criteria. + """ + session_ids: List[str] = [] + try: + if user_id is not None: + # Query using user_id index + response = self.table.query( + IndexName="user_id-index", + KeyConditionExpression=Key("user_id").eq(user_id), + ProjectionExpression="session_id", + ) + items = response.get("Items", []) + session_ids.extend([item["session_id"] for item in items if "session_id" in item]) + elif entity_id is not None: + if self.mode == "agent": + # Query using agent_id index + response = self.table.query( + IndexName="agent_id-index", + KeyConditionExpression=Key("agent_id").eq(entity_id), + ProjectionExpression="session_id", + ) + else: + # Query using workflow_id index + response = self.table.query( + IndexName="workflow_id-index", + KeyConditionExpression=Key("workflow_id").eq(entity_id), + ProjectionExpression="session_id", + ) + items = response.get("Items", []) + session_ids.extend([item["session_id"] for item in items if "session_id" in item]) + else: + # Scan the whole table + response = self.table.scan(ProjectionExpression="session_id") + items = response.get("Items", []) + session_ids.extend([item["session_id"] for item in items if "session_id" in item]) + except Exception as e: + logger.error(f"Error retrieving session IDs: {e}") + return session_ids + + def get_all_sessions(self, user_id: Optional[str] = None, entity_id: Optional[str] = None) -> List[Session]: + """ + Retrieve all sessions, optionally filtered by user_id and/or entity_id. + + Args: + user_id (Optional[str], optional): User ID to filter by. Defaults to None. + entity_id (Optional[str], optional): Entity ID to filter by. Defaults to None. + + Returns: + List[Session]: List of AgentSession or WorkflowSession objects matching the criteria. 
+ """ + sessions: List[Session] = [] + try: + if user_id is not None: + if self.mode == "agent": + # Query using user_id index + response = self.table.query( + IndexName="user_id-index", + KeyConditionExpression=Key("user_id").eq(user_id), + ProjectionExpression="session_id, agent_id, user_id, memory, agent_data, session_data, extra_data, created_at, updated_at", + ) + else: + # Query using user_id index + response = self.table.query( + IndexName="user_id-index", + KeyConditionExpression=Key("user_id").eq(user_id), + ProjectionExpression="session_id, workflow_id, user_id, memory, workflow_data, session_data, extra_data, created_at, updated_at", + ) + items = response.get("Items", []) + for item in items: + item = self._deserialize_item(item) + _session: Optional[Session] = None + if self.mode == "agent": + _session = AgentSession.from_dict(item) + else: + _session = WorkflowSession.from_dict(item) # type: ignore + if _session is not None: + sessions.append(_session) + elif entity_id is not None: + if self.mode == "agent": + # Query using agent_id index + response = self.table.query( + IndexName="agent_id-index", + KeyConditionExpression=Key("agent_id").eq(entity_id), + ProjectionExpression="session_id, agent_id, user_id, memory, agent_data, session_data, extra_data, created_at, updated_at", + ) + else: + # Query using workflow_id index + response = self.table.query( + IndexName="workflow_id-index", + KeyConditionExpression=Key("workflow_id").eq(entity_id), + ProjectionExpression="session_id, workflow_id, user_id, memory, workflow_data, session_data, extra_data, created_at, updated_at", + ) + items = response.get("Items", []) + for item in items: + item = self._deserialize_item(item) + if self.mode == "agent": + _session = AgentSession.from_dict(item) # type: ignore + else: + _session = WorkflowSession.from_dict(item) # type: ignore + if _session is not None: + sessions.append(_session) + else: + # Scan the whole table + response = self.table.scan( + ProjectionExpression="session_id, agent_id, user_id, memory, agent_data, session_data, extra_data, created_at, updated_at" + ) + items = response.get("Items", []) + for item in items: + item = self._deserialize_item(item) + if self.mode == "agent": + _session = AgentSession.from_dict(item) # type: ignore + else: + _session = WorkflowSession.from_dict(item) # type: ignore + if _session is not None: + sessions.append(_session) + except Exception as e: + logger.error(f"Error retrieving sessions: {e}") + return sessions + + def upsert(self, session: Session) -> Optional[Session]: + """ + Create or update a Session in the database. + + Args: + session (Session): The session data to upsert. + + Returns: + Optional[Session]: The upserted Session, or None if operation failed. + """ + try: + item = asdict(session) + + # Add timestamps + current_time = int(time.time()) + if "created_at" not in item or item["created_at"] is None: + item["created_at"] = current_time + item["updated_at"] = current_time + + # Convert data to DynamoDB compatible format + item = self._serialize_item(item) + + # Put item into DynamoDB + self.table.put_item(Item=item) + return self.read(session.session_id) + except Exception as e: + logger.error(f"Error upserting session: {e}") + return None + + def delete_session(self, session_id: Optional[str] = None): + """ + Delete a session from the database. + + Args: + session_id (Optional[str], optional): ID of the session to delete. Defaults to None. 
+ """ + if session_id is None: + logger.warning("No session_id provided for deletion.") + return + try: + self.table.delete_item(Key={"session_id": session_id}) + logger.info(f"Successfully deleted session with session_id: {session_id}") + except Exception as e: + logger.error(f"Error deleting session: {e}") + + def drop(self) -> None: + """ + Drop the table from the database if it exists. + """ + try: + self.table.delete() + self.table.wait_until_not_exists() + logger.debug(f"Table '{self.table_name}' deleted successfully.") + except Exception as e: + logger.error(f"Error deleting table '{self.table_name}': {e}") + + def upgrade_schema(self) -> None: + """ + Upgrade the schema to the latest version. + This method is currently a placeholder and does not perform any actions. + """ + pass + + def _serialize_item(self, item: Dict[str, Any]) -> Dict[str, Any]: + """ + Serialize item to be compatible with DynamoDB. + + Args: + item (Dict[str, Any]): The item to serialize. + + Returns: + Dict[str, Any]: The serialized item. + """ + + def serialize_value(value): + if isinstance(value, float): + return Decimal(str(value)) + elif isinstance(value, dict): + return {k: serialize_value(v) for k, v in value.items()} + elif isinstance(value, list): + return [serialize_value(v) for v in value] + else: + return value + + return {k: serialize_value(v) for k, v in item.items() if v is not None} + + def _deserialize_item(self, item: Dict[str, Any]) -> Dict[str, Any]: + """ + Deserialize item from DynamoDB format. + + Args: + item (Dict[str, Any]): The item to deserialize. + + Returns: + Dict[str, Any]: The deserialized item. + """ + + def deserialize_value(value): + if isinstance(value, Decimal): + if value % 1 == 0: + return int(value) + else: + return float(value) + elif isinstance(value, dict): + return {k: deserialize_value(v) for k, v in value.items()} + elif isinstance(value, list): + return [deserialize_value(v) for v in value] + else: + return value + + return {k: deserialize_value(v) for k, v in item.items()} diff --git a/libs/agno/agno/storage/json.py b/libs/agno/agno/storage/json.py new file mode 100644 index 0000000000..a76833c073 --- /dev/null +++ b/libs/agno/agno/storage/json.py @@ -0,0 +1,141 @@ +import json +import time +from dataclasses import asdict +from pathlib import Path +from typing import List, Literal, Optional, Union + +from agno.storage.base import Storage +from agno.storage.session import Session +from agno.storage.session.agent import AgentSession +from agno.storage.session.workflow import WorkflowSession +from agno.utils.log import logger + + +class JsonStorage(Storage): + def __init__(self, dir_path: Union[str, Path], mode: Optional[Literal["agent", "workflow"]] = "agent"): + super().__init__(mode) + self.dir_path = Path(dir_path) + self.dir_path.mkdir(parents=True, exist_ok=True) + + def serialize(self, data: dict) -> str: + return json.dumps(data, ensure_ascii=False, indent=4) + + def deserialize(self, data: str) -> dict: + return json.loads(data) + + def create(self) -> None: + """Create the storage if it doesn't exist.""" + if not self.dir_path.exists(): + self.dir_path.mkdir(parents=True, exist_ok=True) + + def read(self, session_id: str, user_id: Optional[str] = None) -> Optional[Session]: + """Read an AgentSession from storage.""" + try: + with open(self.dir_path / f"{session_id}.json", "r", encoding="utf-8") as f: + data = self.deserialize(f.read()) + if user_id and data["user_id"] != user_id: + return None + if self.mode == "agent": + return 
AgentSession.from_dict(data) + elif self.mode == "workflow": + return WorkflowSession.from_dict(data) + except FileNotFoundError: + return None + + def get_all_session_ids(self, user_id: Optional[str] = None, entity_id: Optional[str] = None) -> List[str]: + """Get all session IDs, optionally filtered by user_id and/or entity_id.""" + session_ids = [] + for file in self.dir_path.glob("*.json"): + with open(file, "r", encoding="utf-8") as f: + data = self.deserialize(f.read()) + if user_id or entity_id: + if user_id and entity_id: + if self.mode == "agent" and data["agent_id"] == entity_id and data["user_id"] == user_id: + session_ids.append(data["session_id"]) + elif ( + self.mode == "workflow" and data["workflow_id"] == entity_id and data["user_id"] == user_id + ): + session_ids.append(data["session_id"]) + elif user_id and data["user_id"] == user_id: + session_ids.append(data["session_id"]) + elif entity_id: + if self.mode == "agent" and data["agent_id"] == entity_id: + session_ids.append(data["session_id"]) + elif self.mode == "workflow" and data["workflow_id"] == entity_id: + session_ids.append(data["session_id"]) + else: + # No filters applied, add all session_ids + session_ids.append(data["session_id"]) + return session_ids + + def get_all_sessions(self, user_id: Optional[str] = None, entity_id: Optional[str] = None) -> List[Session]: + """Get all sessions, optionally filtered by user_id and/or entity_id.""" + sessions: List[Session] = [] + for file in self.dir_path.glob("*.json"): + with open(file, "r", encoding="utf-8") as f: + data = self.deserialize(f.read()) + if user_id or entity_id: + _session: Optional[Session] = None + + if user_id and entity_id: + if self.mode == "agent" and data["agent_id"] == entity_id and data["user_id"] == user_id: + _session = AgentSession.from_dict(data) + elif ( + self.mode == "workflow" and data["workflow_id"] == entity_id and data["user_id"] == user_id + ): + _session = WorkflowSession.from_dict(data) + elif user_id and data["user_id"] == user_id: + if self.mode == "agent": + _session = AgentSession.from_dict(data) + elif self.mode == "workflow": + _session = WorkflowSession.from_dict(data) + elif entity_id: + if self.mode == "agent" and data["agent_id"] == entity_id: + _session = AgentSession.from_dict(data) + elif self.mode == "workflow" and data["workflow_id"] == entity_id: + _session = WorkflowSession.from_dict(data) + + if _session: + sessions.append(_session) + else: + # No filters applied, add all sessions + if self.mode == "agent": + _session = AgentSession.from_dict(data) + elif self.mode == "workflow": + _session = WorkflowSession.from_dict(data) + if _session: + sessions.append(_session) + return sessions + + def upsert(self, session: Session) -> Optional[Session]: + """Insert or update a Session in storage.""" + try: + data = asdict(session) + data["updated_at"] = int(time.time()) + if "created_at" not in data: + data["created_at"] = data["updated_at"] + + with open(self.dir_path / f"{session.session_id}.json", "w", encoding="utf-8") as f: + f.write(self.serialize(data)) + return session + except Exception as e: + logger.error(f"Error upserting session: {e}") + return None + + def delete_session(self, session_id: Optional[str] = None): + """Delete a session from storage.""" + if session_id is None: + return + try: + (self.dir_path / f"{session_id}.json").unlink(missing_ok=True) + except Exception as e: + logger.error(f"Error deleting session: {e}") + + def drop(self) -> None: + """Drop all sessions from storage.""" + for file in 
self.dir_path.glob("*.json"): + file.unlink() + + def upgrade_schema(self) -> None: + """Upgrade the schema of the storage.""" + pass diff --git a/libs/agno/agno/storage/mongodb.py b/libs/agno/agno/storage/mongodb.py new file mode 100644 index 0000000000..f47e305167 --- /dev/null +++ b/libs/agno/agno/storage/mongodb.py @@ -0,0 +1,249 @@ +from datetime import datetime, timezone +from typing import List, Literal, Optional +from uuid import UUID + +from agno.storage.base import Storage +from agno.storage.session import Session +from agno.storage.session.agent import AgentSession +from agno.storage.session.workflow import WorkflowSession +from agno.utils.log import logger + +try: + from pymongo import MongoClient + from pymongo.collection import Collection + from pymongo.database import Database + from pymongo.errors import PyMongoError +except ImportError: + raise ImportError("`pymongo` not installed. Please install it with `pip install pymongo`") + + +class MongoDbStorage(Storage): + def __init__( + self, + collection_name: str, + db_url: Optional[str] = None, + db_name: str = "agno", + client: Optional[MongoClient] = None, + mode: Optional[Literal["agent", "workflow"]] = "agent", + ): + """ + This class provides agent storage using MongoDB. + + Args: + collection_name: Name of the collection to store agent sessions + db_url: MongoDB connection URL + db_name: Name of the database + client: Optional existing MongoDB client + """ + super().__init__(mode) + self._client: Optional[MongoClient] = client + if self._client is None and db_url is not None: + self._client = MongoClient(db_url) + elif self._client is None: + self._client = MongoClient() + + if self._client is None: + raise ValueError("Must provide either db_url or client") + + self.collection_name: str = collection_name + self.db_name: str = db_name + self.db: Database = self._client[self.db_name] + self.collection: Collection = self.db[self.collection_name] + + def create(self) -> None: + """Create necessary indexes for the collection""" + try: + # Create indexes + self.collection.create_index("session_id", unique=True) + self.collection.create_index("user_id") + self.collection.create_index("created_at") + if self.mode == "agent": + self.collection.create_index("agent_id") + elif self.mode == "workflow": + self.collection.create_index("workflow_id") + except PyMongoError as e: + logger.error(f"Error creating indexes: {e}") + raise + + def read(self, session_id: str, user_id: Optional[str] = None) -> Optional[Session]: + """Read a Session from MongoDB + Args: + session_id: ID of the session to read + user_id: ID of the user to read + Returns: + Optional[Session]: The session if found, otherwise None + """ + try: + query = {"session_id": session_id} + if user_id: + query["user_id"] = user_id + + doc = self.collection.find_one(query) + if doc: + # Remove MongoDB _id before converting to AgentSession + doc.pop("_id", None) + if self.mode == "agent": + return AgentSession.from_dict(doc) + elif self.mode == "workflow": + return WorkflowSession.from_dict(doc) + return None + except PyMongoError as e: + logger.error(f"Error reading session: {e}") + return None + + def get_all_session_ids(self, user_id: Optional[str] = None, entity_id: Optional[str] = None) -> List[str]: + """Get all session IDs matching the criteria + Args: + user_id: ID of the user to read + entity_id: ID of the entity to read + Returns: + List[str]: List of session IDs + """ + try: + query = {} + if user_id is not None: + query["user_id"] = user_id + if entity_id is not 
None: + if self.mode == "agent": + query["agent_id"] = entity_id + elif self.mode == "workflow": + query["workflow_id"] = entity_id + + cursor = self.collection.find(query, {"session_id": 1}).sort("created_at", -1) + + return [str(doc["session_id"]) for doc in cursor] + except PyMongoError as e: + logger.error(f"Error getting session IDs: {e}") + return [] + + def get_all_sessions(self, user_id: Optional[str] = None, entity_id: Optional[str] = None) -> List[Session]: + """Get all sessions matching the criteria + Args: + user_id: ID of the user to read + entity_id: ID of the agent / workflow to read + Returns: + List[Session]: List of sessions + """ + try: + query = {} + if user_id is not None: + query["user_id"] = user_id + if entity_id is not None: + if self.mode == "agent": + query["agent_id"] = entity_id + elif self.mode == "workflow": + query["workflow_id"] = entity_id + + cursor = self.collection.find(query).sort("created_at", -1) + sessions: List[Session] = [] + for doc in cursor: + # Remove MongoDB _id before converting to a Session + doc.pop("_id", None) + if self.mode == "agent": + _agent_session = AgentSession.from_dict(doc) + if _agent_session is not None: + sessions.append(_agent_session) + elif self.mode == "workflow": + _workflow_session = WorkflowSession.from_dict(doc) + if _workflow_session is not None: + sessions.append(_workflow_session) + return sessions + except PyMongoError as e: + logger.error(f"Error getting sessions: {e}") + return [] + + def upsert(self, session: Session, create_and_retry: bool = True) -> Optional[Session]: + """Upsert a session + Args: + session (Session): The session to upsert + create_and_retry (bool): Unused by the MongoDB backend; kept for interface parity + Returns: + Optional[Session]: The upserted session, otherwise None + """ + try: + # Convert session to dict and add timestamps + session_dict = session.to_dict() + now = datetime.now(timezone.utc) + timestamp = int(now.timestamp()) + + # Handle UUID serialization + if isinstance(session.session_id, UUID): + session_dict["session_id"] = str(session.session_id) + + # Add version field for optimistic locking + if "_version" not in session_dict: + session_dict["_version"] = 1 + else: + session_dict["_version"] += 1 + + update_data = {**session_dict, "updated_at": timestamp} + + # For new documents, set created_at + query = {"session_id": session_dict["session_id"]} + + doc = self.collection.find_one(query) + if not doc: + update_data["created_at"] = timestamp + + result = self.collection.update_one(query, {"$set": update_data}, upsert=True) + + if result.acknowledged: + return self.read(session_id=session_dict["session_id"]) + return None + + except PyMongoError as e: + logger.error(f"Error upserting session: {e}") + return None + + def delete_session(self, session_id: Optional[str] = None) -> None: + """Delete a session + Args: + session_id: ID of the session to delete + Returns: + None + """ + if session_id is None: + logger.warning("No session_id provided for deletion") + return + + try: + result = self.collection.delete_one({"session_id": session_id}) + if result.deleted_count == 0: + logger.debug(f"No session found with session_id: {session_id}") + else: + logger.debug(f"Successfully deleted session with session_id: {session_id}") + except PyMongoError as e: + logger.error(f"Error deleting session: {e}") + + def drop(self) -> None: + """Drop the collection + Returns: + None + """ + try: + self.collection.drop() + except PyMongoError as e: + logger.error(f"Error dropping collection: {e}") + + def upgrade_schema(self) -> None: + """Placeholder for schema upgrades""" + pass + + def __deepcopy__(self, memo): + """Create a deep copy of the MongoDbStorage instance""" + from copy import deepcopy + + # Create a new instance without calling __init__ + cls = self.__class__ + copied_obj = cls.__new__(cls) + memo[id(self)] = copied_obj + + # Deep copy attributes + for k, v in self.__dict__.items(): + if k in {"_client", "db", "collection"}: + # Reuse MongoDB connections without copying + setattr(copied_obj, k, v) + else: + setattr(copied_obj, k, deepcopy(v, memo)) + + return copied_obj
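A minimal sketch of the MongoDB backend in use; the connection URL and names below are placeholders, not values from this patch:

```python
# Sketch: MongoDbStorage usage (requires `pip install pymongo`; URL is a placeholder).
from agno.storage.mongodb import MongoDbStorage

storage = MongoDbStorage(
    collection_name="agent_sessions",
    db_url="mongodb://localhost:27017",
    db_name="agno",
    mode="agent",
)
storage.create()  # unique index on session_id, plus user_id/created_at/agent_id

# upsert() bumps the document's _version field, refreshes updated_at, and only
# sets created_at when the document is new; read() strips Mongo's _id before
# rebuilding an AgentSession/WorkflowSession.
print(storage.get_all_session_ids(user_id="user-1"))
```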
collection: {e}") + + def upgrade_schema(self) -> None: + """Placeholder for schema upgrades""" + pass + + def __deepcopy__(self, memo): + """Create a deep copy of the MongoDbStorage instance""" + from copy import deepcopy + + # Create a new instance without calling __init__ + cls = self.__class__ + copied_obj = cls.__new__(cls) + memo[id(self)] = copied_obj + + # Deep copy attributes + for k, v in self.__dict__.items(): + if k in {"_client", "db", "collection"}: + # Reuse MongoDB connections without copying + setattr(copied_obj, k, v) + else: + setattr(copied_obj, k, deepcopy(v, memo)) + + return copied_obj diff --git a/libs/agno/agno/storage/postgres.py b/libs/agno/agno/storage/postgres.py new file mode 100644 index 0000000000..b51147f484 --- /dev/null +++ b/libs/agno/agno/storage/postgres.py @@ -0,0 +1,484 @@ +import time +from typing import List, Literal, Optional + +from agno.storage.base import Storage +from agno.storage.session import Session +from agno.storage.session.agent import AgentSession +from agno.storage.session.workflow import WorkflowSession +from agno.utils.log import logger + +try: + from sqlalchemy.dialects import postgresql + from sqlalchemy.engine import Engine, create_engine + from sqlalchemy.inspection import inspect + from sqlalchemy.orm import scoped_session, sessionmaker + from sqlalchemy.schema import Column, MetaData, Table + from sqlalchemy.sql.expression import select, text + from sqlalchemy.types import BigInteger, String +except ImportError: + raise ImportError("`sqlalchemy` not installed. Please install it using `pip install sqlalchemy`") + + +class PostgresStorage(Storage): + def __init__( + self, + table_name: str, + schema: Optional[str] = "ai", + db_url: Optional[str] = None, + db_engine: Optional[Engine] = None, + schema_version: int = 1, + auto_upgrade_schema: bool = False, + mode: Optional[Literal["agent", "workflow"]] = "agent", + ): + """ + This class provides agent storage using a PostgreSQL table. + + The following order is used to determine the database connection: + 1. Use the db_engine if provided + 2. Use the db_url + 3. Raise an error if neither is provided + + Args: + table_name (str): Name of the table to store Agent sessions. + schema (Optional[str]): The schema to use for the table. Defaults to "ai". + db_url (Optional[str]): The database URL to connect to. + db_engine (Optional[Engine]): The SQLAlchemy database engine to use. + schema_version (int): Version of the schema. Defaults to 1. + auto_upgrade_schema (bool): Whether to automatically upgrade the schema. + mode (Optional[Literal["agent", "workflow"]]): The mode of the storage. + Raises: + ValueError: If neither db_url nor db_engine is provided. 
+ """ + super().__init__(mode) + _engine: Optional[Engine] = db_engine + if _engine is None and db_url is not None: + _engine = create_engine(db_url) + + if _engine is None: + raise ValueError("Must provide either db_url or db_engine") + + # Database attributes + self.table_name: str = table_name + self.schema: Optional[str] = schema + self.db_url: Optional[str] = db_url + self.db_engine: Engine = _engine + self.metadata: MetaData = MetaData(schema=self.schema) + self.inspector = inspect(self.db_engine) + + # Table schema version + self.schema_version: int = schema_version + # Automatically upgrade schema if True + self.auto_upgrade_schema: bool = auto_upgrade_schema + + # Database session + self.Session: scoped_session = scoped_session(sessionmaker(bind=self.db_engine)) + # Database table for storage + self.table: Table = self.get_table() + logger.debug(f"Created PostgresStorage: '{self.schema}.{self.table_name}'") + + @property + def mode(self) -> Literal["agent", "workflow"]: + """Get the mode of the storage.""" + return super().mode + + @mode.setter + def mode(self, value: Optional[Literal["agent", "workflow"]]) -> None: + """Set the mode and refresh the table if mode changes.""" + super(PostgresStorage, type(self)).mode.fset(self, value) # type: ignore + if value is not None: + self.table = self.get_table() + + def get_table_v1(self) -> Table: + """ + Define the table schema for version 1. + + Returns: + Table: SQLAlchemy Table object representing the schema. + """ + # Common columns for both agent and workflow modes + common_columns = [ + Column("session_id", String, primary_key=True), + Column("user_id", String, index=True), + Column("memory", postgresql.JSONB), + Column("session_data", postgresql.JSONB), + Column("extra_data", postgresql.JSONB), + Column("created_at", BigInteger, server_default=text("(extract(epoch from now()))::bigint")), + Column("updated_at", BigInteger, server_onupdate=text("(extract(epoch from now()))::bigint")), + ] + + # Mode-specific columns + if self.mode == "agent": + specific_columns = [ + Column("agent_id", String, index=True), + Column("agent_data", postgresql.JSONB), + ] + else: + specific_columns = [ + Column("workflow_id", String, index=True), + Column("workflow_data", postgresql.JSONB), + ] + + # Create table with all columns + table = Table( + self.table_name, self.metadata, *common_columns, *specific_columns, extend_existing=True, schema=self.schema + ) + + return table + + def get_table(self) -> Table: + """ + Get the table schema based on the schema version. + + Returns: + Table: SQLAlchemy Table object for the current schema version. + + Raises: + ValueError: If an unsupported schema version is specified. + """ + if self.schema_version == 1: + return self.get_table_v1() + else: + raise ValueError(f"Unsupported schema version: {self.schema_version}") + + def table_exists(self) -> bool: + """ + Check if the table exists in the database. + + Returns: + bool: True if the table exists, False otherwise. 
+ """ + try: + # Use a direct SQL query to check if the table exists + with self.Session() as sess: + if self.schema is not None: + exists_query = text( + "SELECT 1 FROM information_schema.tables WHERE table_schema = :schema AND table_name = :table" + ) + exists = ( + sess.execute(exists_query, {"schema": self.schema, "table": self.table_name}).scalar() + is not None + ) + else: + exists_query = text("SELECT 1 FROM information_schema.tables WHERE table_name = :table") + exists = sess.execute(exists_query, {"table": self.table_name}).scalar() is not None + + logger.debug(f"Table '{self.table.fullname}' does {'not' if not exists else ''} exist") + return exists + + except Exception as e: + logger.error(f"Error checking if table exists: {e}") + return False + + def create(self) -> None: + """ + Create the table if it does not exist. + """ + self.table = self.get_table() + if not self.table_exists(): + try: + with self.Session() as sess, sess.begin(): + if self.schema is not None: + logger.debug(f"Creating schema: {self.schema}") + sess.execute(text(f"CREATE SCHEMA IF NOT EXISTS {self.schema};")) + + logger.debug(f"Creating table: {self.table_name}") + + # First create the table without indexes + table_without_indexes = Table( + self.table_name, + MetaData(schema=self.schema), + *[c.copy() for c in self.table.columns], + schema=self.schema, + ) + table_without_indexes.create(self.db_engine, checkfirst=True) + + # Then create each index individually with error handling + for idx in self.table.indexes: + try: + idx_name = idx.name + logger.debug(f"Creating index: {idx_name}") + + # Check if index already exists + with self.Session() as sess: + if self.schema: + exists_query = text( + "SELECT 1 FROM pg_indexes WHERE schemaname = :schema AND indexname = :index_name" + ) + exists = ( + sess.execute(exists_query, {"schema": self.schema, "index_name": idx_name}).scalar() + is not None + ) + else: + exists_query = text("SELECT 1 FROM pg_indexes WHERE indexname = :index_name") + exists = sess.execute(exists_query, {"index_name": idx_name}).scalar() is not None + + if not exists: + idx.create(self.db_engine) + else: + logger.debug(f"Index {idx_name} already exists, skipping creation") + + except Exception as e: + # Log the error but continue with other indexes + logger.warning(f"Error creating index {idx.name}: {e}") + + except Exception as e: + logger.error(f"Could not create table: '{self.table.fullname}': {e}") + raise + + def read(self, session_id: str, user_id: Optional[str] = None) -> Optional[Session]: + """ + Read an Session from the database. + + Args: + session_id (str): ID of the session to read. + user_id (Optional[str]): User ID to filter by. Defaults to None. + + Returns: + Optional[Session]: Session object if found, None otherwise. 
+ """ + try: + with self.Session() as sess: + stmt = select(self.table).where(self.table.c.session_id == session_id) + if user_id: + stmt = stmt.where(self.table.c.user_id == user_id) + result = sess.execute(stmt).fetchone() + if self.mode == "agent": + return AgentSession.from_dict(result._mapping) if result is not None else None + else: + return WorkflowSession.from_dict(result._mapping) if result is not None else None + except Exception as e: + if "does not exist" in str(e): + logger.debug(f"Table does not exist: {self.table.name}") + logger.debug("Creating table for future transactions") + self.create() + else: + logger.debug(f"Exception reading from table: {e}") + return None + + def get_all_session_ids(self, user_id: Optional[str] = None, entity_id: Optional[str] = None) -> List[str]: + """ + Get all session IDs, optionally filtered by user_id and/or entity_id. + + Args: + user_id (Optional[str]): The ID of the user to filter by. + entity_id (Optional[str]): The ID of the agent / workflow to filter by. + + Returns: + List[str]: List of session IDs matching the criteria. + """ + try: + with self.Session() as sess, sess.begin(): + # get all session_ids + stmt = select(self.table.c.session_id) + if user_id is not None: + stmt = stmt.where(self.table.c.user_id == user_id) + if entity_id is not None: + if self.mode == "agent": + stmt = stmt.where(self.table.c.agent_id == entity_id) + else: + stmt = stmt.where(self.table.c.workflow_id == entity_id) + + # order by created_at desc + stmt = stmt.order_by(self.table.c.created_at.desc()) + # execute query + rows = sess.execute(stmt).fetchall() + return [row[0] for row in rows] if rows is not None else [] + except Exception as e: + logger.debug(f"Exception reading from table: {e}") + logger.debug(f"Table does not exist: {self.table.name}") + logger.debug("Creating table for future transactions") + self.create() + return [] + + def get_all_sessions(self, user_id: Optional[str] = None, entity_id: Optional[str] = None) -> List[Session]: + """ + Get all sessions, optionally filtered by user_id and/or entity_id. + + Args: + user_id (Optional[str]): The ID of the user to filter by. + entity_id (Optional[str]): The ID of the agent / workflow to filter by. + + Returns: + List[Session]: List of Session objects matching the criteria. + """ + try: + with self.Session() as sess, sess.begin(): + # get all sessions + stmt = select(self.table) + if user_id is not None: + stmt = stmt.where(self.table.c.user_id == user_id) + if entity_id is not None: + if self.mode == "agent": + stmt = stmt.where(self.table.c.agent_id == entity_id) + else: + stmt = stmt.where(self.table.c.workflow_id == entity_id) + # order by created_at desc + stmt = stmt.order_by(self.table.c.created_at.desc()) + # execute query + rows = sess.execute(stmt).fetchall() + if rows is not None: + if self.mode == "agent": + return [AgentSession.from_dict(row._mapping) for row in rows] # type: ignore + else: + return [WorkflowSession.from_dict(row._mapping) for row in rows] # type: ignore + else: + return [] + except Exception as e: + logger.debug(f"Exception reading from table: {e}") + logger.debug(f"Table does not exist: {self.table.name}") + logger.debug("Creating table for future transactions") + self.create() + return [] + + def upsert(self, session: Session, create_and_retry: bool = True) -> Optional[Session]: + """ + Insert or update an Session in the database. + + Args: + session (Session): The session data to upsert. + create_and_retry (bool): Retry upsert if table does not exist. 
+ + Returns: + Optional[Session]: The upserted Session, or None if operation failed. + """ + try: + with self.Session() as sess, sess.begin(): + # Create an insert statement + if self.mode == "agent": + stmt = postgresql.insert(self.table).values( + session_id=session.session_id, + agent_id=session.agent_id, # type: ignore + user_id=session.user_id, + memory=session.memory, + agent_data=session.agent_data, # type: ignore + session_data=session.session_data, + extra_data=session.extra_data, + ) + # Define the upsert if the session_id already exists + # See: https://docs.sqlalchemy.org/en/20/dialects/postgresql.html#postgresql-insert-on-conflict + stmt = stmt.on_conflict_do_update( + index_elements=["session_id"], + set_=dict( + agent_id=session.agent_id, # type: ignore + user_id=session.user_id, + memory=session.memory, + agent_data=session.agent_data, # type: ignore + session_data=session.session_data, + extra_data=session.extra_data, + updated_at=int(time.time()), + ), # The updated value for each column + ) + else: + stmt = postgresql.insert(self.table).values( + session_id=session.session_id, + workflow_id=session.workflow_id, # type: ignore + user_id=session.user_id, + memory=session.memory, + workflow_data=session.workflow_data, # type: ignore + session_data=session.session_data, + extra_data=session.extra_data, + ) + # Define the upsert if the session_id already exists + # See: https://docs.sqlalchemy.org/en/20/dialects/postgresql.html#postgresql-insert-on-conflict + stmt = stmt.on_conflict_do_update( + index_elements=["session_id"], + set_=dict( + workflow_id=session.workflow_id, # type: ignore + user_id=session.user_id, + memory=session.memory, + workflow_data=session.workflow_data, # type: ignore + session_data=session.session_data, + extra_data=session.extra_data, + updated_at=int(time.time()), + ), # The updated value for each column + ) + + sess.execute(stmt) + except Exception as e: + logger.debug(f"Exception upserting into table: {e}") + if create_and_retry and not self.table_exists(): + logger.debug(f"Table does not exist: {self.table.name}") + logger.debug("Creating table and retrying upsert") + self.create() + return self.upsert(session, create_and_retry=False) + return None + return self.read(session_id=session.session_id) + + def delete_session(self, session_id: Optional[str] = None): + """ + Delete a session from the database. + + Args: + session_id (Optional[str], optional): ID of the session to delete. Defaults to None. + + Note: + Errors during deletion are logged and swallowed, not raised. + """ + if session_id is None: + logger.warning("No session_id provided for deletion.") + return + + try: + with self.Session() as sess, sess.begin(): + # Delete the session with the given session_id + delete_stmt = self.table.delete().where(self.table.c.session_id == session_id) + result = sess.execute(delete_stmt) + if result.rowcount == 0: + logger.debug(f"No session found with session_id: {session_id}") + else: + logger.debug(f"Successfully deleted session with session_id: {session_id}") + except Exception as e: + logger.error(f"Error deleting session: {e}") + + def drop(self) -> None: + """ + Drop the table from the database if it exists. + """ + if self.table_exists(): + logger.debug(f"Deleting table: {self.table_name}") + # Drop with checkfirst=True to avoid errors if the table doesn't exist + self.table.drop(self.db_engine, checkfirst=True) + # Clear metadata to ensure indexes are recreated properly + self.metadata = MetaData(schema=self.schema) + self.table = self.get_table() + + def upgrade_schema(self) -> None: + """ + Upgrade the schema to the latest version. + This method is currently a placeholder and does not perform any actions. + """ + pass + + def __deepcopy__(self, memo): + """ + Create a deep copy of the PostgresStorage instance, handling unpickleable attributes. + + Args: + memo (dict): A dictionary of objects already copied during the current copying pass. + + Returns: + PostgresStorage: A deep-copied instance of PostgresStorage. + """ + from copy import deepcopy + + # Create a new instance without calling __init__ + cls = self.__class__ + copied_obj = cls.__new__(cls) + memo[id(self)] = copied_obj + + # Deep copy attributes + for k, v in self.__dict__.items(): + if k in {"metadata", "table", "inspector"}: + continue + # Reuse db_engine and Session without copying + elif k in {"db_engine", "Session"}: + setattr(copied_obj, k, v) + else: + setattr(copied_obj, k, deepcopy(v, memo)) + + # Recreate metadata and table for the copied instance + copied_obj.metadata = MetaData(schema=copied_obj.schema) + copied_obj.inspector = inspect(copied_obj.db_engine) + copied_obj.table = copied_obj.get_table() + + return copied_obj
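The Postgres implementation above is the most defensive of the set: schema creation, table creation, and each index are attempted separately. A minimal wiring sketch follows; the connection URL is a placeholder, not a value from this patch:

```python
# Sketch: PostgresStorage usage; db_url is a placeholder.
from agno.storage.postgres import PostgresStorage

storage = PostgresStorage(
    table_name="agent_sessions",
    schema="ai",  # tables default to the "ai" schema
    db_url="postgresql+psycopg://user:pass@localhost:5432/agno",
    mode="agent",
)
storage.create()  # CREATE SCHEMA IF NOT EXISTS, then table, then index-by-index

# upsert() maps to INSERT ... ON CONFLICT (session_id) DO UPDATE, so repeated
# calls are idempotent; if the table is missing it is created and the upsert
# retried exactly once (create_and_retry=False on the recursive call).
```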
+ """ + if self.table_exists(): + logger.debug(f"Deleting table: {self.table_name}") + # Drop with checkfirst=True to avoid errors if the table doesn't exist + self.table.drop(self.db_engine, checkfirst=True) + # Clear metadata to ensure indexes are recreated properly + self.metadata = MetaData(schema=self.schema) + self.table = self.get_table() + + def upgrade_schema(self) -> None: + """ + Upgrade the schema to the latest version. + This method is currently a placeholder and does not perform any actions. + """ + pass + + def __deepcopy__(self, memo): + """ + Create a deep copy of the PostgresStorage instance, handling unpickleable attributes. + + Args: + memo (dict): A dictionary of objects already copied during the current copying pass. + + Returns: + PostgresStorage: A deep-copied instance of PostgresStorage. + """ + from copy import deepcopy + + # Create a new instance without calling __init__ + cls = self.__class__ + copied_obj = cls.__new__(cls) + memo[id(self)] = copied_obj + + # Deep copy attributes + for k, v in self.__dict__.items(): + if k in {"metadata", "table", "inspector"}: + continue + # Reuse db_engine and Session without copying + elif k in {"db_engine", "Session"}: + setattr(copied_obj, k, v) + else: + setattr(copied_obj, k, deepcopy(v, memo)) + + # Recreate metadata and table for the copied instance + copied_obj.metadata = MetaData(schema=copied_obj.schema) + copied_obj.inspector = inspect(copied_obj.db_engine) + copied_obj.table = copied_obj.get_table() + + return copied_obj diff --git a/libs/agno/agno/storage/session/__init__.py b/libs/agno/agno/storage/session/__init__.py new file mode 100644 index 0000000000..286c4d4d2f --- /dev/null +++ b/libs/agno/agno/storage/session/__init__.py @@ -0,0 +1,6 @@ +from typing import Union + +from agno.storage.session.agent import AgentSession +from agno.storage.session.workflow import WorkflowSession + +Session = Union[AgentSession, WorkflowSession] diff --git a/libs/agno/agno/storage/agent/session.py b/libs/agno/agno/storage/session/agent.py similarity index 99% rename from libs/agno/agno/storage/agent/session.py rename to libs/agno/agno/storage/session/agent.py index ddd32b7c67..a7b87c80bb 100644 --- a/libs/agno/agno/storage/agent/session.py +++ b/libs/agno/agno/storage/session/agent.py @@ -12,14 +12,10 @@ class AgentSession: # Session UUID session_id: str - # ID of the agent that this session is associated with - agent_id: Optional[str] = None # ID of the user interacting with this agent user_id: Optional[str] = None # Agent Memory memory: Optional[Dict[str, Any]] = None - # Agent Data: agent_id, name and model - agent_data: Optional[Dict[str, Any]] = None # Session Data: session_name, session_state, images, videos, audio session_data: Optional[Dict[str, Any]] = None # Extra Data stored with this agent @@ -29,6 +25,11 @@ class AgentSession: # The unix timestamp when this session was last updated updated_at: Optional[int] = None + # ID of the agent that this session is associated with + agent_id: Optional[str] = None + # Agent Data: agent_id, name and model + agent_data: Optional[Dict[str, Any]] = None + def to_dict(self) -> Dict[str, Any]: return asdict(self) diff --git a/libs/agno/agno/storage/workflow/session.py b/libs/agno/agno/storage/session/workflow.py similarity index 91% rename from libs/agno/agno/storage/workflow/session.py rename to libs/agno/agno/storage/session/workflow.py index e5c60275fa..1803c229bb 100644 --- a/libs/agno/agno/storage/workflow/session.py +++ b/libs/agno/agno/storage/session/workflow.py @@ -12,23 
diff --git a/libs/agno/agno/storage/agent/session.py b/libs/agno/agno/storage/session/agent.py similarity index 99% rename from libs/agno/agno/storage/agent/session.py rename to libs/agno/agno/storage/session/agent.py index ddd32b7c67..a7b87c80bb 100644 --- a/libs/agno/agno/storage/agent/session.py +++ b/libs/agno/agno/storage/session/agent.py @@ -12,14 +12,10 @@ class AgentSession: # Session UUID session_id: str - # ID of the agent that this session is associated with - agent_id: Optional[str] = None # ID of the user interacting with this agent user_id: Optional[str] = None # Agent Memory memory: Optional[Dict[str, Any]] = None - # Agent Data: agent_id, name and model - agent_data: Optional[Dict[str, Any]] = None # Session Data: session_name, session_state, images, videos, audio session_data: Optional[Dict[str, Any]] = None # Extra Data stored with this agent @@ -29,6 +25,11 @@ class AgentSession: # The unix timestamp when this session was last updated updated_at: Optional[int] = None + # ID of the agent that this session is associated with + agent_id: Optional[str] = None + # Agent Data: agent_id, name and model + agent_data: Optional[Dict[str, Any]] = None + def to_dict(self) -> Dict[str, Any]: return asdict(self) diff --git a/libs/agno/agno/storage/workflow/session.py b/libs/agno/agno/storage/session/workflow.py similarity index 91% rename from libs/agno/agno/storage/workflow/session.py rename to libs/agno/agno/storage/session/workflow.py index e5c60275fa..1803c229bb 100644 --- a/libs/agno/agno/storage/workflow/session.py +++ b/libs/agno/agno/storage/session/workflow.py @@ -12,23 +12,24 @@ class WorkflowSession: # Session UUID session_id: str - # ID of the workflow that this session is associated with - workflow_id: Optional[str] = None # ID of the user interacting with this workflow user_id: Optional[str] = None # Workflow Memory memory: Optional[Dict[str, Any]] = None - # Workflow Data - workflow_data: Optional[Dict[str, Any]] = None - # Session Data + # Session Data: session_name, session_state, images, videos, audio session_data: Optional[Dict[str, Any]] = None # Extra Data stored with this workflow extra_data: Optional[Dict[str, Any]] = None # The unix timestamp when this session was created created_at: Optional[int] = None # The unix timestamp when this session was last updated updated_at: Optional[int] = None + # ID of the workflow that this session is associated with + workflow_id: Optional[str] = None + # Workflow Data + workflow_data: Optional[Dict[str, Any]] = None + def to_dict(self) -> Dict[str, Any]: return asdict(self) diff --git a/libs/agno/agno/storage/singlestore.py b/libs/agno/agno/storage/singlestore.py new file mode 100644 index 0000000000..964a6af41e --- /dev/null +++ b/libs/agno/agno/storage/singlestore.py @@ -0,0 +1,404 @@ +import json +from typing import Any, List, Literal, Optional + +from agno.storage.base import Storage +from agno.storage.session import Session +from agno.storage.session.agent import AgentSession +from agno.storage.session.workflow import WorkflowSession +from agno.utils.log import logger + +try: + from sqlalchemy.dialects import mysql + from sqlalchemy.engine import Engine, create_engine + from sqlalchemy.engine.row import Row + from sqlalchemy.inspection import inspect + from sqlalchemy.orm import Session as SqlSession + from sqlalchemy.orm import sessionmaker + from sqlalchemy.schema import Column, MetaData, Table + from sqlalchemy.sql.expression import select, text +except ImportError: + raise ImportError("`sqlalchemy` not installed. Please install it using `pip install sqlalchemy`") + + +class SingleStoreStorage(Storage): + def __init__( + self, + table_name: str, + schema: Optional[str] = "ai", + db_url: Optional[str] = None, + db_engine: Optional[Engine] = None, + schema_version: int = 1, + auto_upgrade_schema: bool = False, + mode: Optional[Literal["agent", "workflow"]] = "agent", + ): + """ + This class provides Agent and Workflow session storage using a SingleStore table. + + The following order is used to determine the database connection: + 1. Use the db_engine if provided + 2. Use the db_url if provided + + Args: + table_name (str): The name of the table to store sessions. + schema (Optional[str], optional): The schema of the table. Defaults to "ai". + db_url (Optional[str], optional): The database URL. Defaults to None. + db_engine (Optional[Engine], optional): The database engine. Defaults to None. + schema_version (int, optional): The schema version. Defaults to 1. + auto_upgrade_schema (bool, optional): Automatically upgrade the schema. Defaults to False. + mode (Optional[Literal["agent", "workflow"]], optional): The mode of the storage. Defaults to "agent".
+ """ + super().__init__(mode) + _engine: Optional[Engine] = db_engine + if _engine is None and db_url is not None: + _engine = create_engine(db_url, connect_args={"charset": "utf8mb4"}) + + if _engine is None: + raise ValueError("Must provide either db_url or db_engine") + + # Database attributes + self.table_name: str = table_name + self.schema: Optional[str] = schema + self.db_url: Optional[str] = db_url + self.db_engine: Engine = _engine + self.metadata: MetaData = MetaData(schema=self.schema) + + # Table schema version + self.schema_version: int = schema_version + # Automatically upgrade schema if True + self.auto_upgrade_schema: bool = auto_upgrade_schema + + # Database session + self.SqlSession: sessionmaker[SqlSession] = sessionmaker(bind=self.db_engine) + # Database table for storage + self.table: Table = self.get_table() + + @property + def mode(self) -> Literal["agent", "workflow"]: + """Get the mode of the storage.""" + return super().mode + + @mode.setter + def mode(self, value: Optional[Literal["agent", "workflow"]]) -> None: + """Set the mode and refresh the table if mode changes.""" + super(SingleStoreStorage, type(self)).mode.fset(self, value) # type: ignore + if value is not None: + self.table = self.get_table() + + def get_table_v1(self) -> Table: + common_columns = [ + Column("session_id", mysql.TEXT, primary_key=True), + Column("user_id", mysql.TEXT), + Column("memory", mysql.JSON), + Column("session_data", mysql.JSON), + Column("extra_data", mysql.JSON), + Column("created_at", mysql.BIGINT), + Column("updated_at", mysql.BIGINT), + ] + + if self.mode == "agent": + specific_columns = [ + Column("agent_id", mysql.TEXT), + Column("agent_data", mysql.JSON), + ] + else: + specific_columns = [ + Column("workflow_id", mysql.TEXT), + Column("workflow_data", mysql.JSON), + ] + + # Create table with all columns + table = Table( + self.table_name, self.metadata, *common_columns, *specific_columns, extend_existing=True, schema=self.schema + ) + + return table + + def get_table(self) -> Table: + if self.schema_version == 1: + return self.get_table_v1() + else: + raise ValueError(f"Unsupported schema version: {self.schema_version}") + + def table_exists(self) -> bool: + logger.debug(f"Checking if table exists: {self.table.name}") + try: + return inspect(self.db_engine).has_table(self.table.name, schema=self.schema) + except Exception as e: + logger.error(e) + return False + + def create(self) -> None: + self.table = self.get_table() + if not self.table_exists(): + logger.info(f"\nCreating table: {self.table_name}\n") + self.table.create(self.db_engine) + + def _read(self, session: SqlSession, session_id: str, user_id: Optional[str] = None) -> Optional[Row[Any]]: + stmt = select(self.table).where(self.table.c.session_id == session_id) + if user_id is not None: + stmt = stmt.where(self.table.c.user_id == user_id) + try: + return session.execute(stmt).first() + except Exception as e: + logger.debug(f"Exception reading from table: {e}") + logger.debug(f"Table does not exist: {self.table.name}") + logger.debug(f"Creating table: {self.table_name}") + self.create() + return None + + def read(self, session_id: str, user_id: Optional[str] = None) -> Optional[Session]: + with self.SqlSession.begin() as sess: + existing_row: Optional[Row[Any]] = self._read(session=sess, session_id=session_id, user_id=user_id) + if existing_row is not None: + if self.mode == "agent": + return AgentSession.from_dict(existing_row._mapping) # type: ignore + else: + return 
WorkflowSession.from_dict(existing_row._mapping) # type: ignore + return None + + def get_all_session_ids(self, user_id: Optional[str] = None, entity_id: Optional[str] = None) -> List[str]: + session_ids: List[str] = [] + try: + with self.SqlSession.begin() as sess: + # get all session_ids for this user + stmt = select(self.table) + if user_id is not None: + stmt = stmt.where(self.table.c.user_id == user_id) + if entity_id is not None: + if self.mode == "agent": + stmt = stmt.where(self.table.c.agent_id == entity_id) + else: + stmt = stmt.where(self.table.c.workflow_id == entity_id) + # order by created_at desc + stmt = stmt.order_by(self.table.c.created_at.desc()) + # execute query + rows = sess.execute(stmt).fetchall() + for row in rows: + if row is not None and row.session_id is not None: + session_ids.append(row.session_id) + except Exception as e: + logger.error(f"An unexpected error occurred: {str(e)}") + return session_ids + + def get_all_sessions(self, user_id: Optional[str] = None, entity_id: Optional[str] = None) -> List[Session]: + sessions: List[Session] = [] + try: + with self.SqlSession.begin() as sess: + # get all sessions for this user + stmt = select(self.table) + if user_id is not None: + stmt = stmt.where(self.table.c.user_id == user_id) + if entity_id is not None: + if self.mode == "agent": + stmt = stmt.where(self.table.c.agent_id == entity_id) + else: + stmt = stmt.where(self.table.c.workflow_id == entity_id) + # order by created_at desc + stmt = stmt.order_by(self.table.c.created_at.desc()) + # execute query + rows = sess.execute(stmt).fetchall() + for row in rows: + if row.session_id is not None: + if self.mode == "agent": + _agent_session = AgentSession.from_dict(row._mapping) # type: ignore + if _agent_session is not None: + sessions.append(_agent_session) + else: + _workflow_session = WorkflowSession.from_dict(row._mapping) # type: ignore + if _workflow_session is not None: + sessions.append(_workflow_session) + except Exception: + logger.debug(f"Table does not exist: {self.table.name}") + return sessions + + def upsert(self, session: Session) -> Optional[Session]: + """ + Create a new session if it does not exist, otherwise update the existing session. 
+ """ + + with self.SqlSession.begin() as sess: + # Create an insert statement using MySQL's ON DUPLICATE KEY UPDATE syntax + if self.mode == "agent": + upsert_sql = text( + f""" + INSERT INTO {self.schema}.{self.table_name} + (session_id, agent_id, user_id, memory, agent_data, session_data, extra_data, created_at, updated_at) + VALUES + (:session_id, :agent_id, :user_id, :memory, :agent_data, :session_data, :extra_data, UNIX_TIMESTAMP(), NULL) + ON DUPLICATE KEY UPDATE + agent_id = VALUES(agent_id), + user_id = VALUES(user_id), + memory = VALUES(memory), + agent_data = VALUES(agent_data), + session_data = VALUES(session_data), + extra_data = VALUES(extra_data), + updated_at = UNIX_TIMESTAMP(); + """ + ) + else: + upsert_sql = text( + f""" + INSERT INTO {self.schema}.{self.table_name} + (session_id, workflow_id, user_id, memory, workflow_data, session_data, extra_data, created_at, updated_at) + VALUES + (:session_id, :workflow_id, :user_id, :memory, :workflow_data, :session_data, :extra_data, UNIX_TIMESTAMP(), NULL) + ON DUPLICATE KEY UPDATE + workflow_id = VALUES(workflow_id), + user_id = VALUES(user_id), + memory = VALUES(memory), + workflow_data = VALUES(workflow_data), + session_data = VALUES(session_data), + extra_data = VALUES(extra_data), + updated_at = UNIX_TIMESTAMP(); + """ + ) + + try: + if self.mode == "agent": + sess.execute( + upsert_sql, + { + "session_id": session.session_id, + "agent_id": session.agent_id, # type: ignore + "user_id": session.user_id, + "memory": json.dumps(session.memory, ensure_ascii=False) + if session.memory is not None + else None, + "agent_data": json.dumps(session.agent_data, ensure_ascii=False) # type: ignore + if session.agent_data is not None # type: ignore + else None, + "session_data": json.dumps(session.session_data, ensure_ascii=False) + if session.session_data is not None + else None, + "extra_data": json.dumps(session.extra_data, ensure_ascii=False) + if session.extra_data is not None + else None, + }, + ) + else: + sess.execute( + upsert_sql, + { + "session_id": session.session_id, + "workflow_id": session.workflow_id, # type: ignore + "user_id": session.user_id, + "memory": json.dumps(session.memory, ensure_ascii=False) + if session.memory is not None + else None, + "workflow_data": json.dumps(session.workflow_data, ensure_ascii=False) # type: ignore + if session.workflow_data is not None # type: ignore + else None, + "session_data": json.dumps(session.session_data, ensure_ascii=False) + if session.session_data is not None + else None, + "extra_data": json.dumps(session.extra_data, ensure_ascii=False) + if session.extra_data is not None + else None, + }, + ) + except Exception: + # Create table and try again + self.create() + if self.mode == "agent": + sess.execute( + upsert_sql, + { + "session_id": session.session_id, + "agent_id": session.agent_id, # type: ignore + "user_id": session.user_id, + "memory": json.dumps(session.memory, ensure_ascii=False) + if session.memory is not None + else None, + "agent_data": json.dumps(session.agent_data, ensure_ascii=False) # type: ignore + if session.agent_data is not None # type: ignore + else None, + "session_data": json.dumps(session.session_data, ensure_ascii=False) + if session.session_data is not None + else None, + "extra_data": json.dumps(session.extra_data, ensure_ascii=False) + if session.extra_data is not None + else None, + }, + ) + else: + sess.execute( + upsert_sql, + { + "session_id": session.session_id, + "workflow_id": session.workflow_id, # type: ignore + "user_id": session.user_id, 
+ "memory": json.dumps(session.memory, ensure_ascii=False) + if session.memory is not None + else None, + "workflow_data": json.dumps(session.workflow_data, ensure_ascii=False) # type: ignore + if session.workflow_data is not None # type: ignore + else None, + "session_data": json.dumps(session.session_data, ensure_ascii=False) + if session.session_data is not None + else None, + "extra_data": json.dumps(session.extra_data, ensure_ascii=False) + if session.extra_data is not None + else None, + }, + ) + return self.read(session_id=session.session_id) + + def delete_session(self, session_id: Optional[str] = None): + if session_id is None: + logger.warning("No session_id provided for deletion.") + return + + with self.SqlSession() as sess, sess.begin(): + try: + # Delete the session with the given session_id + delete_stmt = self.table.delete().where(self.table.c.session_id == session_id) + result = sess.execute(delete_stmt) + + if result.rowcount == 0: + logger.warning(f"No session found with session_id: {session_id}") + else: + logger.info(f"Successfully deleted session with session_id: {session_id}") + except Exception as e: + logger.error(f"Error deleting session: {e}") + raise + + def drop(self) -> None: + if self.table_exists(): + logger.info(f"Deleting table: {self.table_name}") + self.table.drop(self.db_engine) + + def upgrade_schema(self) -> None: + pass + + def __deepcopy__(self, memo): + """ + Create a deep copy of the SingleStoreAgentStorage instance, handling unpickleable attributes. + + Args: + memo (dict): A dictionary of objects already copied during the current copying pass. + + Returns: + SingleStoreStorage: A deep-copied instance of SingleStoreAgentStorage. + """ + from copy import deepcopy + + # Create a new instance without calling __init__ + cls = self.__class__ + copied_obj = cls.__new__(cls) + memo[id(self)] = copied_obj + + # Deep copy attributes + for k, v in self.__dict__.items(): + if k in {"metadata", "table"}: + continue + # Reuse db_engine and Session without copying + elif k in {"db_engine", "Session"}: + setattr(copied_obj, k, v) + else: + setattr(copied_obj, k, deepcopy(v, memo)) + + # Recreate metadata and table for the copied instance + copied_obj.metadata = MetaData(schema=self.schema) + copied_obj.table = copied_obj.get_table() + + return copied_obj diff --git a/libs/agno/agno/storage/sqlite.py b/libs/agno/agno/storage/sqlite.py new file mode 100644 index 0000000000..c6440b2886 --- /dev/null +++ b/libs/agno/agno/storage/sqlite.py @@ -0,0 +1,471 @@ +import time +from pathlib import Path +from typing import List, Literal, Optional + +from agno.storage.base import Storage +from agno.storage.session import Session +from agno.storage.session.agent import AgentSession +from agno.storage.session.workflow import WorkflowSession +from agno.utils.log import logger + +try: + from sqlalchemy.dialects import sqlite + from sqlalchemy.engine import Engine, create_engine + from sqlalchemy.inspection import inspect + from sqlalchemy.orm import Session as SqlSession + from sqlalchemy.orm import sessionmaker + from sqlalchemy.schema import Column, MetaData, Table + from sqlalchemy.sql import text + from sqlalchemy.sql.expression import select + from sqlalchemy.types import String +except ImportError: + raise ImportError("`sqlalchemy` not installed. 
diff --git a/libs/agno/agno/storage/sqlite.py b/libs/agno/agno/storage/sqlite.py new file mode 100644 index 0000000000..c6440b2886 --- /dev/null +++ b/libs/agno/agno/storage/sqlite.py @@ -0,0 +1,471 @@ +import time +from pathlib import Path +from typing import List, Literal, Optional + +from agno.storage.base import Storage +from agno.storage.session import Session +from agno.storage.session.agent import AgentSession +from agno.storage.session.workflow import WorkflowSession +from agno.utils.log import logger + +try: + from sqlalchemy.dialects import sqlite + from sqlalchemy.engine import Engine, create_engine + from sqlalchemy.inspection import inspect + from sqlalchemy.orm import Session as SqlSession + from sqlalchemy.orm import sessionmaker + from sqlalchemy.schema import Column, MetaData, Table + from sqlalchemy.sql import text + from sqlalchemy.sql.expression import select + from sqlalchemy.types import String +except ImportError: + raise ImportError("`sqlalchemy` not installed. Please install it using `pip install sqlalchemy`") + + +class SqliteStorage(Storage): + def __init__( + self, + table_name: str, + db_url: Optional[str] = None, + db_file: Optional[str] = None, + db_engine: Optional[Engine] = None, + schema_version: int = 1, + auto_upgrade_schema: bool = False, + mode: Optional[Literal["agent", "workflow"]] = "agent", + ): + """ + This class provides Agent and Workflow session storage using a SQLite database. + + The following order is used to determine the database connection: + 1. Use the db_engine if provided + 2. Use the db_url + 3. Use the db_file + 4. Create a new in-memory database + + Args: + table_name: The name of the table to store sessions. + db_url: The database URL to connect to. + db_file: The database file to connect to. + db_engine: The SQLAlchemy database engine to use. + mode: The storage mode, "agent" or "workflow". Defaults to "agent". + """ + super().__init__(mode) + _engine: Optional[Engine] = db_engine + if _engine is None and db_url is not None: + _engine = create_engine(db_url) + elif _engine is None and db_file is not None: + # Use the db_file to create the engine + db_path = Path(db_file).resolve() + # Ensure the directory exists + db_path.parent.mkdir(parents=True, exist_ok=True) + _engine = create_engine(f"sqlite:///{db_path}") + elif _engine is None: + # Fall back to an in-memory database only when no engine was provided; + # a bare `else` here would clobber a caller-supplied db_engine + _engine = create_engine("sqlite://") + + if _engine is None: + raise ValueError("Must provide either db_url, db_file or db_engine") + + # Database attributes + self.table_name: str = table_name + self.db_url: Optional[str] = db_url + self.db_engine: Engine = _engine + self.metadata: MetaData = MetaData() + self.inspector = inspect(self.db_engine) + + # Table schema version + self.schema_version: int = schema_version + # Automatically upgrade schema if True + self.auto_upgrade_schema: bool = auto_upgrade_schema + + # Database session + self.SqlSession: sessionmaker[SqlSession] = sessionmaker(bind=self.db_engine) + # Database table for storage + self.table: Table = self.get_table() + + @property + def mode(self) -> Optional[Literal["agent", "workflow"]]: + """Get the mode of the storage.""" + return super().mode + + @mode.setter + def mode(self, value: Optional[Literal["agent", "workflow"]]) -> None: + """Set the mode and refresh the table if mode changes.""" + super(SqliteStorage, type(self)).mode.fset(self, value) # type: ignore + if value is not None: + self.table = self.get_table() + + def get_table_v1(self) -> Table: + """ + Define the table schema for version 1. + + Returns: + Table: SQLAlchemy Table object representing the schema. + """ + common_columns = [ + Column("session_id", String, primary_key=True), + Column("user_id", String, index=True), + Column("memory", sqlite.JSON), + Column("session_data", sqlite.JSON), + Column("extra_data", sqlite.JSON), + Column("created_at", sqlite.INTEGER, default=lambda: int(time.time())), + Column("updated_at", sqlite.INTEGER, onupdate=lambda: int(time.time())), + ] + + # Mode-specific columns + if self.mode == "agent": + specific_columns = [ + Column("agent_id", String, index=True), + Column("agent_data", sqlite.JSON), + ] + else: + specific_columns = [ + Column("workflow_id", String, index=True), + Column("workflow_data", sqlite.JSON), + ] + + # Create table with all columns + table = Table( + self.table_name, + self.metadata, + *common_columns, + *specific_columns, + extend_existing=True, + sqlite_autoincrement=True, + ) + + return table + + def get_table(self) -> Table: + """ + Get the table schema based on the schema version. + + Returns: + Table: SQLAlchemy Table object for the current schema version.
+ + Raises: + ValueError: If an unsupported schema version is specified. + """ + if self.schema_version == 1: + return self.get_table_v1() + else: + raise ValueError(f"Unsupported schema version: {self.schema_version}") + + def table_exists(self) -> bool: + """ + Check if the table exists in the database. + + Returns: + bool: True if the table exists, False otherwise. + """ + try: + # For SQLite, we need to check the sqlite_master table + with self.SqlSession() as sess: + result = sess.execute( + text("SELECT name FROM sqlite_master WHERE type='table' AND name=:table_name"), + {"table_name": self.table_name}, + ).scalar() + return result is not None + except Exception as e: + logger.error(f"Error checking if table exists: {e}") + return False + + def create(self) -> None: + """ + Create the table if it doesn't exist. + """ + self.table = self.get_table() + if not self.table_exists(): + logger.debug(f"Creating table: {self.table.name}") + try: + # First create the table without indexes + table_without_indexes = Table( + self.table_name, + MetaData(), + *[c.copy() for c in self.table.columns], + ) + table_without_indexes.create(self.db_engine, checkfirst=True) + + # Then create each index individually with error handling + for idx in self.table.indexes: + try: + idx_name = idx.name + logger.debug(f"Creating index: {idx_name}") + + # Check if index already exists using SQLite's schema table + with self.SqlSession() as sess: + exists_query = text("SELECT 1 FROM sqlite_master WHERE type='index' AND name=:index_name") + exists = sess.execute(exists_query, {"index_name": idx_name}).scalar() is not None + + if not exists: + idx.create(self.db_engine) + else: + logger.debug(f"Index {idx_name} already exists, skipping creation") + + except Exception as e: + # Log the error but continue with other indexes + logger.warning(f"Error creating index {idx.name}: {e}") + + except Exception as e: + logger.error(f"Error creating table: {e}") + raise + + def read(self, session_id: str, user_id: Optional[str] = None) -> Optional[Session]: + """ + Read a Session from the database. + + Args: + session_id (str): ID of the session to read. + user_id (Optional[str]): User ID to filter by. Defaults to None. + + Returns: + Optional[Session]: Session object if found, None otherwise. + """ + try: + with self.SqlSession() as sess: + stmt = select(self.table).where(self.table.c.session_id == session_id) + if user_id: + stmt = stmt.where(self.table.c.user_id == user_id) + result = sess.execute(stmt).fetchone() + if self.mode == "agent": + return AgentSession.from_dict(result._mapping) if result is not None else None # type: ignore + else: + return WorkflowSession.from_dict(result._mapping) if result is not None else None # type: ignore + except Exception as e: + if "no such table" in str(e): + logger.debug(f"Table does not exist: {self.table.name}") + self.create() + else: + logger.debug(f"Exception reading from table: {e}") + return None + + def get_all_session_ids(self, user_id: Optional[str] = None, entity_id: Optional[str] = None) -> List[str]: + """ + Get all session IDs, optionally filtered by user_id and/or entity_id. + + Args: + user_id (Optional[str]): The ID of the user to filter by. + entity_id (Optional[str]): The ID of the agent / workflow to filter by. + + Returns: + List[str]: List of session IDs matching the criteria. 
+ """ + try: + with self.SqlSession() as sess, sess.begin(): + # get all session_ids + stmt = select(self.table.c.session_id) + if user_id is not None: + stmt = stmt.where(self.table.c.user_id == user_id) + if entity_id is not None: + if self.mode == "agent": + stmt = stmt.where(self.table.c.agent_id == entity_id) + else: + stmt = stmt.where(self.table.c.workflow_id == entity_id) + # order by created_at desc + stmt = stmt.order_by(self.table.c.created_at.desc()) + # execute query + rows = sess.execute(stmt).fetchall() + return [row[0] for row in rows] if rows is not None else [] + except Exception as e: + if "no such table" in str(e): + logger.debug(f"Table does not exist: {self.table.name}") + self.create() + else: + logger.debug(f"Exception reading from table: {e}") + return [] + + def get_all_sessions(self, user_id: Optional[str] = None, entity_id: Optional[str] = None) -> List[Session]: + """ + Get all sessions, optionally filtered by user_id and/or entity_id. + + Args: + user_id (Optional[str]): The ID of the user to filter by. + entity_id (Optional[str]): The ID of the agent / workflow to filter by. + + Returns: + List[Session]: List of Session objects matching the criteria. + """ + try: + with self.SqlSession() as sess, sess.begin(): + # get all sessions + stmt = select(self.table) + if user_id is not None: + stmt = stmt.where(self.table.c.user_id == user_id) + if entity_id is not None: + if self.mode == "agent": + stmt = stmt.where(self.table.c.agent_id == entity_id) + else: + stmt = stmt.where(self.table.c.workflow_id == entity_id) + # order by created_at desc + stmt = stmt.order_by(self.table.c.created_at.desc()) + # execute query + rows = sess.execute(stmt).fetchall() + if rows is not None: + if self.mode == "agent": + return [AgentSession.from_dict(row._mapping) for row in rows] # type: ignore + else: + return [WorkflowSession.from_dict(row._mapping) for row in rows] # type: ignore + else: + return [] + except Exception as e: + if "no such table" in str(e): + logger.debug(f"Table does not exist: {self.table.name}") + self.create() + else: + logger.debug(f"Exception reading from table: {e}") + return [] + + def upsert(self, session: Session, create_and_retry: bool = True) -> Optional[Session]: + """ + Insert or update a Session in the database. + + Args: + session (Session): The session data to upsert. + create_and_retry (bool): Retry upsert if table does not exist. + + Returns: + Optional[Session]: The upserted Session, or None if operation failed. 
+ """ + try: + with self.SqlSession() as sess, sess.begin(): + if self.mode == "agent": + # Create an insert statement + stmt = sqlite.insert(self.table).values( + session_id=session.session_id, + agent_id=session.agent_id, # type: ignore + user_id=session.user_id, + memory=session.memory, + agent_data=session.agent_data, # type: ignore + session_data=session.session_data, + extra_data=session.extra_data, + ) + + # Define the upsert if the session_id already exists + # See: https://docs.sqlalchemy.org/en/20/dialects/sqlite.html#insert-on-conflict-upsert + stmt = stmt.on_conflict_do_update( + index_elements=["session_id"], + set_=dict( + agent_id=session.agent_id, # type: ignore + user_id=session.user_id, + memory=session.memory, + agent_data=session.agent_data, # type: ignore + session_data=session.session_data, + extra_data=session.extra_data, + updated_at=int(time.time()), + ), # The updated value for each column + ) + else: + # Create an insert statement + stmt = sqlite.insert(self.table).values( + session_id=session.session_id, + workflow_id=session.workflow_id, # type: ignore + user_id=session.user_id, + memory=session.memory, + workflow_data=session.workflow_data, # type: ignore + session_data=session.session_data, + extra_data=session.extra_data, + ) + + # Define the upsert if the session_id already exists + # See: https://docs.sqlalchemy.org/en/20/dialects/sqlite.html#insert-on-conflict-upsert + stmt = stmt.on_conflict_do_update( + index_elements=["session_id"], + set_=dict( + workflow_id=session.workflow_id, # type: ignore + user_id=session.user_id, + memory=session.memory, + workflow_data=session.workflow_data, # type: ignore + session_data=session.session_data, + extra_data=session.extra_data, + updated_at=int(time.time()), + ), # The updated value for each column + ) + + sess.execute(stmt) + except Exception as e: + if create_and_retry and not self.table_exists(): + logger.debug(f"Table does not exist: {self.table.name}") + logger.debug("Creating table and retrying upsert") + self.create() + return self.upsert(session, create_and_retry=False) + else: + logger.debug(f"Exception upserting into table: {e}") + return None + return self.read(session_id=session.session_id) + + def delete_session(self, session_id: Optional[str] = None): + """ + Delete a workflow session from the database. + + Args: + session_id (Optional[str]): The ID of the session to delete. + + Raises: + ValueError: If session_id is not provided. + """ + if session_id is None: + logger.warning("No session_id provided for deletion.") + return + + try: + with self.SqlSession() as sess, sess.begin(): + # Delete the session with the given session_id + delete_stmt = self.table.delete().where(self.table.c.session_id == session_id) + result = sess.execute(delete_stmt) + if result.rowcount == 0: + logger.debug(f"No session found with session_id: {session_id}") + else: + logger.debug(f"Successfully deleted session with session_id: {session_id}") + except Exception as e: + logger.error(f"Error deleting session: {e}") + + def drop(self) -> None: + """ + Drop the table from the database if it exists. + """ + if self.table_exists(): + logger.debug(f"Deleting table: {self.table_name}") + # Drop with checkfirst=True to avoid errors if the table doesn't exist + self.table.drop(self.db_engine, checkfirst=True) + # Clear metadata to ensure indexes are recreated properly + self.metadata = MetaData() + self.table = self.get_table() + + def upgrade_schema(self) -> None: + """ + Upgrade the schema of the workflow storage table. 
+ This method is currently a placeholder and does not perform any actions. + """ + pass + + def __deepcopy__(self, memo): + """ + Create a deep copy of the SqliteStorage instance, handling unpickleable attributes. + + Args: + memo (dict): A dictionary of objects already copied during the current copying pass. + + Returns: + SqliteStorage: A deep-copied instance of SqliteStorage. + """ + from copy import deepcopy + + # Create a new instance without calling __init__ + cls = self.__class__ + copied_obj = cls.__new__(cls) + memo[id(self)] = copied_obj + + # Deep copy attributes + for k, v in self.__dict__.items(): + if k in {"metadata", "table", "inspector"}: + continue + # Reuse db_engine and SqlSession without copying + elif k in {"db_engine", "SqlSession"}: + setattr(copied_obj, k, v) + else: + setattr(copied_obj, k, deepcopy(v, memo)) + + # Recreate metadata and table for the copied instance + copied_obj.metadata = MetaData() + copied_obj.inspector = inspect(copied_obj.db_engine) + copied_obj.table = copied_obj.get_table() + + return copied_obj
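A minimal sketch of the new SQLite backend in both modes, using only constructor arguments that appear in the class above:

```python
# Sketch: one SqliteStorage class covering agent and workflow sessions.
from agno.storage.sqlite import SqliteStorage

# db_file parents are created on demand; with no db_url/db_file/db_engine
# the storage falls back to an in-memory sqlite:// database.
agent_storage = SqliteStorage(table_name="agent_sessions", db_file="tmp/agents.db")

# mode="workflow" swaps the agent_id/agent_data columns for
# workflow_id/workflow_data in the generated table schema.
workflow_storage = SqliteStorage(
    table_name="workflow_sessions", db_file="tmp/workflows.db", mode="workflow"
)
```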
diff --git a/libs/agno/agno/storage/workflow/base.py b/libs/agno/agno/storage/workflow/base.py deleted file mode 100644 index d3a1a4906c..0000000000 --- a/libs/agno/agno/storage/workflow/base.py +++ /dev/null @@ -1,40 +0,0 @@ -from abc import ABC, abstractmethod -from typing import List, Optional - -from agno.storage.workflow.session import WorkflowSession - - -class WorkflowStorage(ABC): - @abstractmethod - def create(self) -> None: - raise NotImplementedError - - @abstractmethod - def read(self, session_id: str, user_id: Optional[str] = None) -> Optional[WorkflowSession]: - raise NotImplementedError - - @abstractmethod - def get_all_session_ids(self, user_id: Optional[str] = None, workflow_id: Optional[str] = None) -> List[str]: - raise NotImplementedError - - @abstractmethod - def get_all_sessions( - self, user_id: Optional[str] = None, workflow_id: Optional[str] = None - ) -> List[WorkflowSession]: - raise NotImplementedError - - @abstractmethod - def upsert(self, session: WorkflowSession) -> Optional[WorkflowSession]: - raise NotImplementedError - - @abstractmethod - def delete_session(self, session_id: Optional[str] = None): - raise NotImplementedError - - @abstractmethod - def drop(self) -> None: - raise NotImplementedError - - @abstractmethod - def upgrade_schema(self) -> None: - raise NotImplementedError diff --git a/libs/agno/agno/storage/workflow/mongodb.py b/libs/agno/agno/storage/workflow/mongodb.py index 34250cb8db..27c0aa7165 100644 --- a/libs/agno/agno/storage/workflow/mongodb.py +++ b/libs/agno/agno/storage/workflow/mongodb.py @@ -1,233 +1 @@ -from datetime import datetime, timezone -from typing import List, Optional -from uuid import UUID - -try: - from pymongo import MongoClient - from pymongo.collection import Collection - from pymongo.database import Database - from pymongo.errors import PyMongoError -except ImportError: - raise ImportError("`pymongo` not installed. Please install it with `pip install pymongo`") - -from agno.storage.workflow.base import WorkflowStorage -from agno.storage.workflow.session import WorkflowSession -from agno.utils.log import logger - - -class MongoDbWorkflowStorage(WorkflowStorage): - def __init__( - self, - collection_name: str, - db_url: Optional[str] = None, - db_name: str = "agno", - client: Optional[MongoClient] = None, - ): - """ - This class provides workflow storage using MongoDB. - - Args: - collection_name: Name of the collection to store workflow sessions - db_url: MongoDB connection URL - db_name: Name of the database - client: Optional existing MongoDB client - schema_version: Version of the schema to use - auto_upgrade_schema: Whether to automatically upgrade the schema - """ - self._client: Optional[MongoClient] = client - if self._client is None and db_url is not None: - self._client = MongoClient(db_url) - elif self._client is None: - self._client = MongoClient() - - if self._client is None: - raise ValueError("Must provide either db_url or client") - - self.collection_name: str = collection_name - self.db_name: str = db_name - - self.db: Database = self._client[self.db_name] - self.collection: Collection = self.db[self.collection_name] - - def create(self) -> None: - """Create necessary indexes for the collection""" - try: - # Create indexes - self.collection.create_index("session_id", unique=True) - self.collection.create_index("user_id") - self.collection.create_index("workflow_id") - self.collection.create_index("created_at") - except PyMongoError as e: - logger.error(f"Error creating indexes: {e}") - raise - - def read(self, session_id: str, user_id: Optional[str] = None) -> Optional[WorkflowSession]: - """Read a workflow session from MongoDB - Args: - session_id: ID of the session to read - user_id: ID of the user to read - Returns: - WorkflowSession: The session if found, otherwise None - """ - try: - query = {"session_id": session_id} - if user_id: - query["user_id"] = user_id - - doc = self.collection.find_one(query) - if doc: - # Remove MongoDB _id before converting to WorkflowSession - doc.pop("_id", None) - return WorkflowSession.from_dict(doc) - return None - except PyMongoError as e: - logger.error(f"Error reading session: {e}") - return None - - def get_all_session_ids(self, user_id: Optional[str] = None, workflow_id: Optional[str] = None) -> List[str]: - """Get all session IDs matching the criteria - Args: - user_id: ID of the user to read - workflow_id: ID of the workflow to read - Returns: - List[str]: List of session IDs - """ - try: - query = {} - if user_id is not None: - query["user_id"] = user_id - if workflow_id is not None: - query["workflow_id"] = workflow_id - - cursor = self.collection.find(query, {"session_id": 1}).sort("created_at", -1) - - return [str(doc["session_id"]) for doc in cursor] - except PyMongoError as e: - logger.error(f"Error getting session IDs: {e}") - return [] - - def get_all_sessions( - self, user_id: Optional[str] = None, workflow_id: Optional[str] = None - ) -> List[WorkflowSession]: - """Get all sessions matching the criteria - Args: - user_id: ID of the user to read - workflow_id: ID of the workflow to read - Returns: - List[WorkflowSession]: List of sessions - """ - try: - query = {} - if user_id is not None: - query["user_id"] = user_id - if workflow_id is not None: - query["workflow_id"] = workflow_id - - cursor = self.collection.find(query).sort("created_at", -1) - sessions = [] - for doc in cursor: - # Remove MongoDB _id before converting to WorkflowSession - doc.pop("_id", None) - _workflow_session = WorkflowSession.from_dict(doc) - if _workflow_session is not None: - sessions.append(_workflow_session) - return sessions - except PyMongoError as e: - logger.error(f"Error getting sessions: {e}") - return [] - - def upsert(self, session: WorkflowSession, create_and_retry: bool = True) -> Optional[WorkflowSession]: - """Upsert a workflow session - Args: - session: WorkflowSession to upsert -
create_and_retry: Whether to create a new session if the session_id already exists - Returns: - WorkflowSession: The session if upserted, otherwise None - """ - try: - # Convert session to dict and add timestamps - session_dict = session.to_dict() - now = datetime.now(timezone.utc) - timestamp = int(now.timestamp()) - - # Handle UUID serialization - if isinstance(session.session_id, UUID): - session_dict["session_id"] = str(session.session_id) - - # Add version field for optimistic locking - if "_version" not in session_dict: - session_dict["_version"] = 1 - else: - session_dict["_version"] += 1 - - update_data = {**session_dict, "updated_at": timestamp} - - # For new documents, set created_at - query = {"session_id": session_dict["session_id"]} - - doc = self.collection.find_one(query) - if not doc: - update_data["created_at"] = timestamp - - result = self.collection.update_one(query, {"$set": update_data}, upsert=True) - - if result.acknowledged: - return self.read(session_id=session_dict["session_id"]) - return None - - except PyMongoError as e: - logger.error(f"Error upserting session: {e}") - return None - - def delete_session(self, session_id: Optional[str] = None) -> None: - """Delete a workflow session - Args: - session_id: ID of the session to delete - Returns: - None - """ - if session_id is None: - logger.warning("No session_id provided for deletion") - return - - try: - result = self.collection.delete_one({"session_id": session_id}) - if result.deleted_count == 0: - logger.debug(f"No session found with session_id: {session_id}") - else: - logger.debug(f"Successfully deleted session with session_id: {session_id}") - except PyMongoError as e: - logger.error(f"Error deleting session: {e}") - - def drop(self) -> None: - """Drop the collection - Returns: - None - """ - try: - self.collection.drop() - except PyMongoError as e: - logger.error(f"Error dropping collection: {e}") - - def upgrade_schema(self) -> None: - """Placeholder for schema upgrades""" - pass - - def __deepcopy__(self, memo): - """Create a deep copy of the MongoDbWorkflowStorage instance""" - from copy import deepcopy - - # Create a new instance without calling __init__ - cls = self.__class__ - copied_obj = cls.__new__(cls) - memo[id(self)] = copied_obj - - # Deep copy attributes - for k, v in self.__dict__.items(): - if k in {"_client", "db", "collection"}: - # Reuse MongoDB connections without copying - setattr(copied_obj, k, v) - else: - setattr(copied_obj, k, deepcopy(v, memo)) - - return copied_obj +from agno.storage.mongodb import MongoDbStorage as MongoDbWorkflowStorage # noqa: F401 diff --git a/libs/agno/agno/storage/workflow/postgres.py b/libs/agno/agno/storage/workflow/postgres.py index 548685c38e..21c6e1f0a4 100644 --- a/libs/agno/agno/storage/workflow/postgres.py +++ b/libs/agno/agno/storage/workflow/postgres.py @@ -1,371 +1 @@ -import time -import traceback -from typing import List, Optional - -try: - from sqlalchemy import BigInteger, Column, Engine, Index, MetaData, String, Table, create_engine, inspect - from sqlalchemy.dialects import postgresql - from sqlalchemy.orm import scoped_session, sessionmaker - from sqlalchemy.sql.expression import select, text -except ImportError: - raise ImportError("`sqlalchemy` not installed. 
Please install it with `pip install sqlalchemy`") - -from agno.storage.workflow.base import WorkflowStorage -from agno.storage.workflow.session import WorkflowSession -from agno.utils.log import logger - - -class PostgresWorkflowStorage(WorkflowStorage): - def __init__( - self, - table_name: str, - schema: Optional[str] = "ai", - db_url: Optional[str] = None, - db_engine: Optional[Engine] = None, - schema_version: int = 1, - auto_upgrade_schema: bool = False, - ): - """ - This class provides workflow storage using a PostgreSQL database. - - The following order is used to determine the database connection: - 1. Use the db_engine if provided - 2. Use the db_url - 3. Raise an error if neither is provided - - Args: - table_name (str): The name of the table to store Workflow sessions. - schema (Optional[str]): The schema to use for the table. Defaults to "ai". - db_url (Optional[str]): The database URL to connect to. - db_engine (Optional[Engine]): The SQLAlchemy database engine to use. - schema_version (int): Version of the schema. Defaults to 1. - auto_upgrade_schema (bool): Whether to automatically upgrade the schema. - - Raises: - ValueError: If neither db_url nor db_engine is provided. - """ - _engine: Optional[Engine] = db_engine - if _engine is None and db_url is not None: - _engine = create_engine(db_url) - - if _engine is None: - raise ValueError("Must provide either db_url or db_engine") - - # Database attributes - self.table_name: str = table_name - self.schema: Optional[str] = schema - self.db_url: Optional[str] = db_url - self.db_engine: Engine = _engine - self.metadata: MetaData = MetaData(schema=self.schema) - self.inspector = inspect(self.db_engine) - - # Table schema version - self.schema_version: int = schema_version - # Automatically upgrade schema if True - self.auto_upgrade_schema: bool = auto_upgrade_schema - - # Database session - self.Session: scoped_session = scoped_session(sessionmaker(bind=self.db_engine)) - # Database table for storage - self.table: Table = self.get_table() - logger.debug(f"Created PostgresWorkflowStorage: '{self.schema}.{self.table_name}'") - - def get_table_v1(self) -> Table: - """ - Define the table schema for version 1. - - Returns: - Table: SQLAlchemy Table object representing the schema. - """ - table = Table( - self.table_name, - self.metadata, - # Session UUID: Primary Key - Column("session_id", String, primary_key=True), - # ID of the workflow that this session is associated with - Column("workflow_id", String), - # ID of the user interacting with this workflow - Column("user_id", String), - # Workflow Memory - Column("memory", postgresql.JSONB), - # Workflow Data - Column("workflow_data", postgresql.JSONB), - # Session Data - Column("session_data", postgresql.JSONB), - # Extra Data - Column("extra_data", postgresql.JSONB), - # The Unix timestamp of when this session was created. - Column("created_at", BigInteger, default=lambda: int(time.time())), - # The Unix timestamp of when this session was last updated. - Column("updated_at", BigInteger, onupdate=lambda: int(time.time())), - extend_existing=True, - ) - - # Add indexes - Index(f"idx_{self.table_name}_session_id", table.c.session_id) - Index(f"idx_{self.table_name}_workflow_id", table.c.workflow_id) - Index(f"idx_{self.table_name}_user_id", table.c.user_id) - - return table - - def get_table(self) -> Table: - """ - Get the table schema based on the schema version. - - Returns: - Table: SQLAlchemy Table object for the current schema version. 
- - Raises: - ValueError: If an unsupported schema version is specified. - """ - if self.schema_version == 1: - return self.get_table_v1() - else: - raise ValueError(f"Unsupported schema version: {self.schema_version}") - - def table_exists(self) -> bool: - """ - Check if the table exists in the database. - - Returns: - bool: True if the table exists, False otherwise. - """ - logger.debug(f"Checking if table exists: {self.table.name}") - try: - return self.inspector.has_table(self.table.name, schema=self.schema) - except Exception as e: - logger.error(f"Error checking if table exists: {e}") - return False - - def create(self) -> None: - """ - Create the table if it doesn't exist. - """ - if not self.table_exists(): - try: - with self.Session() as sess, sess.begin(): - if self.schema is not None: - logger.debug(f"Creating schema: {self.schema}") - sess.execute(text(f"CREATE SCHEMA IF NOT EXISTS {self.schema};")) - logger.debug(f"Creating table: {self.table_name}") - self.table.create(self.db_engine, checkfirst=True) - except Exception as e: - logger.error(f"Could not create table: '{self.table.fullname}': {e}") - - def read(self, session_id: str, user_id: Optional[str] = None) -> Optional[WorkflowSession]: - """ - Read a WorkflowSession from the database. - - Args: - session_id (str): The ID of the session to read. - user_id (Optional[str]): The ID of the user associated with the session. - - Returns: - Optional[WorkflowSession]: The WorkflowSession object if found, None otherwise. - """ - try: - with self.Session() as sess: - stmt = select(self.table).where(self.table.c.session_id == session_id) - if user_id: - stmt = stmt.where(self.table.c.user_id == user_id) - result = sess.execute(stmt).fetchone() - return WorkflowSession.from_dict(result._mapping) if result is not None else None # type: ignore - except Exception as e: - logger.debug(f"Exception reading from table: {e}") - logger.debug(f"Table does not exist: {self.table.name}") - logger.debug("Creating table for future transactions") - self.create() - return None - - def get_all_session_ids(self, user_id: Optional[str] = None, workflow_id: Optional[str] = None) -> List[str]: - """ - Get all session IDs, optionally filtered by user_id and/or workflow_id. - - Args: - user_id (Optional[str]): The ID of the user to filter by. - workflow_id (Optional[str]): The ID of the workflow to filter by. - - Returns: - List[str]: List of session IDs matching the criteria. - """ - try: - with self.Session() as sess, sess.begin(): - # get all session_ids - stmt = select(self.table.c.session_id) - if user_id is not None or user_id != "": - stmt = stmt.where(self.table.c.user_id == user_id) - if workflow_id is not None: - stmt = stmt.where(self.table.c.workflow_id == workflow_id) - # order by created_at desc - stmt = stmt.order_by(self.table.c.created_at.desc()) - # execute query - rows = sess.execute(stmt).fetchall() - return [row[0] for row in rows] if rows is not None else [] - except Exception as e: - logger.debug(f"Exception reading from table: {e}") - logger.debug(f"Table does not exist: {self.table.name}") - logger.debug("Creating table for future transactions") - self.create() - return [] - - def get_all_sessions( - self, user_id: Optional[str] = None, workflow_id: Optional[str] = None - ) -> List[WorkflowSession]: - """ - Get all sessions, optionally filtered by user_id and/or workflow_id. - - Args: - user_id (Optional[str]): The ID of the user to filter by. - workflow_id (Optional[str]): The ID of the workflow to filter by. 
- - Returns: - List[WorkflowSession]: List of AgentSession objects matching the criteria. - """ - try: - with self.Session() as sess, sess.begin(): - # get all sessions - stmt = select(self.table) - if user_id is not None and user_id != "": - stmt = stmt.where(self.table.c.user_id == user_id) - if workflow_id is not None: - stmt = stmt.where(self.table.c.workflow_id == workflow_id) - # order by created_at desc - stmt = stmt.order_by(self.table.c.created_at.desc()) - # execute query - rows = sess.execute(stmt).fetchall() - return [WorkflowSession.from_dict(row._mapping) for row in rows] if rows is not None else [] # type: ignore - except Exception as e: - logger.debug(f"Exception reading from table: {e}") - logger.debug(f"Table does not exist: {self.table.name}") - logger.debug("Creating table for future transactions") - self.create() - return [] - - def upsert(self, session: WorkflowSession, create_and_retry: bool = True) -> Optional[WorkflowSession]: - """ - Insert or update a WorkflowSession in the database. - - Args: - session (WorkflowSession): The WorkflowSession object to upsert. - create_and_retry (bool): Retry upsert if table does not exist. - - Returns: - Optional[WorkflowSession]: The upserted WorkflowSession object. - """ - try: - with self.Session() as sess, sess.begin(): - # Create an insert statement - stmt = postgresql.insert(self.table).values( - session_id=session.session_id, - workflow_id=session.workflow_id, - user_id=session.user_id, - memory=session.memory, - workflow_data=session.workflow_data, - session_data=session.session_data, - extra_data=session.extra_data, - ) - - # Define the upsert if the session_id already exists - # See: https://docs.sqlalchemy.org/en/20/dialects/postgresql.html#postgresql-insert-on-conflict - stmt = stmt.on_conflict_do_update( - index_elements=["session_id"], - set_=dict( - workflow_id=session.workflow_id, - user_id=session.user_id, - memory=session.memory, - workflow_data=session.workflow_data, - session_data=session.session_data, - extra_data=session.extra_data, - updated_at=int(time.time()), - ), # The updated value for each column - ) - - sess.execute(stmt) - except TypeError as e: - traceback.print_exc() - logger.error(f"Exception upserting into table: {e}") - return None - except Exception as e: - logger.debug(f"Exception upserting into table: {e}") - if create_and_retry and not self.table_exists(): - logger.debug(f"Table does not exist: {self.table.name}") - logger.debug("Creating table and retrying upsert") - self.create() - return self.upsert(session, create_and_retry=False) - return None - return self.read(session_id=session.session_id) - - def delete_session(self, session_id: Optional[str] = None): - """ - Delete a workflow session from the database. - - Args: - session_id (Optional[str]): The ID of the session to delete. - - Raises: - ValueError: If session_id is not provided. - """ - if session_id is None: - logger.warning("No session_id provided for deletion.") - return - - try: - with self.Session() as sess, sess.begin(): - # Delete the session with the given session_id - delete_stmt = self.table.delete().where(self.table.c.session_id == session_id) - result = sess.execute(delete_stmt) - if result.rowcount == 0: - logger.debug(f"No session found with session_id: {session_id}") - else: - logger.debug(f"Successfully deleted session with session_id: {session_id}") - except Exception as e: - logger.error(f"Error deleting session: {e}") - - def drop(self) -> None: - """ - Drop the table from the database if it exists. 
- """ - if self.table_exists(): - logger.debug(f"Deleting table: {self.table_name}") - self.table.drop(self.db_engine) - - def upgrade_schema(self) -> None: - """ - Upgrade the schema of the workflow storage table. - This method is currently a placeholder and does not perform any actions. - """ - pass - - def __deepcopy__(self, memo): - """ - Create a deep copy of the PostgresWorkflowStorage instance, handling unpickleable attributes. - - Args: - memo (dict): A dictionary of objects already copied during the current copying pass. - - Returns: - PostgresWorkflowStorage: A deep-copied instance of PostgresWorkflowStorage. - """ - from copy import deepcopy - - # Create a new instance without calling __init__ - cls = self.__class__ - copied_obj = cls.__new__(cls) - memo[id(self)] = copied_obj - - # Deep copy attributes - for k, v in self.__dict__.items(): - if k in {"metadata", "table", "inspector"}: - continue - # Reuse db_engine and Session without copying - elif k in {"db_engine", "Session"}: - setattr(copied_obj, k, v) - else: - setattr(copied_obj, k, deepcopy(v, memo)) - - # Recreate metadata and table for the copied instance - copied_obj.metadata = MetaData(schema=copied_obj.schema) - copied_obj.inspector = inspect(copied_obj.db_engine) - copied_obj.table = copied_obj.get_table() - - return copied_obj +from agno.storage.postgres import PostgresStorage as PostgresWorkflowStorage # noqa: F401 diff --git a/libs/agno/agno/storage/workflow/sqlite.py b/libs/agno/agno/storage/workflow/sqlite.py index 787ec966cc..01506944c9 100644 --- a/libs/agno/agno/storage/workflow/sqlite.py +++ b/libs/agno/agno/storage/workflow/sqlite.py @@ -1,364 +1 @@ -import time -import traceback -from pathlib import Path -from typing import List, Optional - -try: - from sqlalchemy.dialects import sqlite - from sqlalchemy.engine import Engine, create_engine - from sqlalchemy.inspection import inspect - from sqlalchemy.orm import Session, sessionmaker - from sqlalchemy.schema import Column, MetaData, Table - from sqlalchemy.sql.expression import select - from sqlalchemy.types import String -except ImportError: - raise ImportError("`sqlalchemy` not installed. Please install it using `pip install sqlalchemy`") - -from agno.storage.workflow.base import WorkflowStorage -from agno.storage.workflow.session import WorkflowSession -from agno.utils.log import logger - - -class SqliteWorkflowStorage(WorkflowStorage): - def __init__( - self, - table_name: str, - db_url: Optional[str] = None, - db_file: Optional[str] = None, - db_engine: Optional[Engine] = None, - schema_version: int = 1, - auto_upgrade_schema: bool = False, - ): - """ - This class provides workflow storage using a sqlite database. - - The following order is used to determine the database connection: - 1. Use the db_engine if provided - 2. Use the db_url - 3. Use the db_file - 4. Create a new in-memory database - - Args: - table_name: The name of the table to store Workflow sessions. - db_url: The database URL to connect to. - db_file: The database file to connect to. - db_engine: The SQLAlchemy database engine to use. 
- """ - _engine: Optional[Engine] = db_engine - if _engine is None and db_url is not None: - _engine = create_engine(db_url) - elif _engine is None and db_file is not None: - # Use the db_file to create the engine - db_path = Path(db_file).resolve() - # Ensure the directory exists - db_path.parent.mkdir(parents=True, exist_ok=True) - _engine = create_engine(f"sqlite:///{db_path}") - else: - _engine = create_engine("sqlite://") - - if _engine is None: - raise ValueError("Must provide either db_url, db_file or db_engine") - - # Database attributes - self.table_name: str = table_name - self.db_url: Optional[str] = db_url - self.db_engine: Engine = _engine - self.metadata: MetaData = MetaData() - self.inspector = inspect(self.db_engine) - - # Table schema version - self.schema_version: int = schema_version - # Automatically upgrade schema if True - self.auto_upgrade_schema: bool = auto_upgrade_schema - - # Database session - self.Session: sessionmaker[Session] = sessionmaker(bind=self.db_engine) - # Database table for storage - self.table: Table = self.get_table() - - def get_table_v1(self) -> Table: - """ - Define the table schema for version 1. - - Returns: - Table: SQLAlchemy Table object representing the schema. - """ - return Table( - self.table_name, - self.metadata, - # Session UUID: Primary Key - Column("session_id", String, primary_key=True), - # ID of the workflow that this session is associated with - Column("workflow_id", String), - # ID of the user interacting with this workflow - Column("user_id", String), - # Workflow Memory - Column("memory", sqlite.JSON), - # Workflow Data - Column("workflow_data", sqlite.JSON), - # Session Data - Column("session_data", sqlite.JSON), - # Extra Data - Column("extra_data", sqlite.JSON), - # The Unix timestamp of when this session was created. - Column("created_at", sqlite.INTEGER, default=lambda: int(time.time())), - # The Unix timestamp of when this session was last updated. - Column("updated_at", sqlite.INTEGER, onupdate=lambda: int(time.time())), - extend_existing=True, - sqlite_autoincrement=True, - ) - - def get_table(self) -> Table: - """ - Get the table schema based on the schema version. - - Returns: - Table: SQLAlchemy Table object for the current schema version. - - Raises: - ValueError: If an unsupported schema version is specified. - """ - if self.schema_version == 1: - return self.get_table_v1() - else: - raise ValueError(f"Unsupported schema version: {self.schema_version}") - - def table_exists(self) -> bool: - """ - Check if the table exists in the database. - - Returns: - bool: True if the table exists, False otherwise. - """ - logger.debug(f"Checking if table exists: {self.table.name}") - try: - return self.inspector.has_table(self.table.name) - except Exception as e: - logger.error(f"Error checking if table exists: {e}") - return False - - def create(self) -> None: - """ - Create the table if it doesn't exist. - """ - if not self.table_exists(): - logger.debug(f"Creating table: {self.table.name}") - self.table.create(self.db_engine, checkfirst=True) - - def read(self, session_id: str, user_id: Optional[str] = None) -> Optional[WorkflowSession]: - """ - Read a WorkflowSession from the database. - - Args: - session_id (str): The ID of the session to read. - user_id (Optional[str]): The ID of the user associated with the session. - - Returns: - Optional[WorkflowSession]: The WorkflowSession object if found, None otherwise. 
- """ - try: - with self.Session() as sess: - stmt = select(self.table).where(self.table.c.session_id == session_id) - if user_id: - stmt = stmt.where(self.table.c.user_id == user_id) - result = sess.execute(stmt).fetchone() - return WorkflowSession.from_dict(result._mapping) if result is not None else None # type: ignore - except Exception as e: - logger.debug(f"Exception reading from table: {e}") - logger.debug(f"Table does not exist: {self.table.name}") - logger.debug("Creating table for future transactions") - self.create() - return None - - def get_all_session_ids(self, user_id: Optional[str] = None, workflow_id: Optional[str] = None) -> List[str]: - """ - Get all session IDs, optionally filtered by user_id and/or workflow_id. - - Args: - user_id (Optional[str]): The ID of the user to filter by. - workflow_id (Optional[str]): The ID of the workflow to filter by. - - Returns: - List[str]: List of session IDs matching the criteria. - """ - try: - with self.Session() as sess, sess.begin(): - # get all session_ids - stmt = select(self.table.c.session_id) - if user_id is not None and user_id != "": - stmt = stmt.where(self.table.c.user_id == user_id) - if workflow_id is not None: - stmt = stmt.where(self.table.c.workflow_id == workflow_id) - # order by created_at desc - stmt = stmt.order_by(self.table.c.created_at.desc()) - # execute query - rows = sess.execute(stmt).fetchall() - return [row[0] for row in rows] if rows is not None else [] - except Exception as e: - logger.debug(f"Exception reading from table: {e}") - logger.debug(f"Table does not exist: {self.table.name}") - logger.debug("Creating table for future transactions") - self.create() - return [] - - def get_all_sessions( - self, user_id: Optional[str] = None, workflow_id: Optional[str] = None - ) -> List[WorkflowSession]: - """ - Get all sessions, optionally filtered by user_id and/or workflow_id. - - Args: - user_id (Optional[str]): The ID of the user to filter by. - workflow_id (Optional[str]): The ID of the workflow to filter by. - - Returns: - List[WorkflowSession]: List of AgentSession objects matching the criteria. - """ - try: - with self.Session() as sess, sess.begin(): - # get all sessions - stmt = select(self.table) - if user_id is not None and user_id != "": - stmt = stmt.where(self.table.c.user_id == user_id) - if workflow_id is not None: - stmt = stmt.where(self.table.c.workflow_id == workflow_id) - # order by created_at desc - stmt = stmt.order_by(self.table.c.created_at.desc()) - # execute query - rows = sess.execute(stmt).fetchall() - return [WorkflowSession.from_dict(row._mapping) for row in rows] if rows is not None else [] # type: ignore - except Exception as e: - logger.debug(f"Exception reading from table: {e}") - logger.debug(f"Table does not exist: {self.table.name}") - logger.debug("Creating table for future transactions") - self.create() - return [] - - def upsert(self, session: WorkflowSession, create_and_retry: bool = True) -> Optional[WorkflowSession]: - """ - Insert or update a WorkflowSession in the database. - - Args: - session (WorkflowSession): The WorkflowSession object to upsert. - create_and_retry (bool): Retry upsert if table does not exist. - - Returns: - Optional[WorkflowSession]: The upserted WorkflowSession object. 
- """ - try: - with self.Session() as sess, sess.begin(): - # Create an insert statement - stmt = sqlite.insert(self.table).values( - session_id=session.session_id, - workflow_id=session.workflow_id, - user_id=session.user_id, - memory=session.memory, - workflow_data=session.workflow_data, - session_data=session.session_data, - extra_data=session.extra_data, - ) - - # Define the upsert if the session_id already exists - # See: https://docs.sqlalchemy.org/en/20/dialects/sqlite.html#insert-on-conflict-upsert - stmt = stmt.on_conflict_do_update( - index_elements=["session_id"], - set_=dict( - workflow_id=session.workflow_id, - user_id=session.user_id, - memory=session.memory, - workflow_data=session.workflow_data, - session_data=session.session_data, - extra_data=session.extra_data, - updated_at=int(time.time()), - ), # The updated value for each column - ) - - sess.execute(stmt) - except TypeError as e: - traceback.print_exc() - logger.error(f"Exception upserting into table: {e}") - return None - except Exception as e: - logger.debug(f"Exception upserting into table: {e}") - if create_and_retry and not self.table_exists(): - logger.debug(f"Table does not exist: {self.table.name}") - logger.debug("Creating table and retrying upsert") - self.create() - return self.upsert(session, create_and_retry=False) - return None - return self.read(session_id=session.session_id) - - def delete_session(self, session_id: Optional[str] = None): - """ - Delete a workflow session from the database. - - Args: - session_id (Optional[str]): The ID of the session to delete. - - Raises: - ValueError: If session_id is not provided. - """ - if session_id is None: - logger.warning("No session_id provided for deletion.") - return - - try: - with self.Session() as sess, sess.begin(): - # Delete the session with the given session_id - delete_stmt = self.table.delete().where(self.table.c.session_id == session_id) - result = sess.execute(delete_stmt) - if result.rowcount == 0: - logger.debug(f"No session found with session_id: {session_id}") - else: - logger.debug(f"Successfully deleted session with session_id: {session_id}") - except Exception as e: - logger.error(f"Error deleting session: {e}") - - def drop(self) -> None: - """ - Drop the table from the database if it exists. - """ - if self.table_exists(): - logger.debug(f"Deleting table: {self.table_name}") - self.table.drop(self.db_engine) - - def upgrade_schema(self) -> None: - """ - Upgrade the schema of the workflow storage table. - This method is currently a placeholder and does not perform any actions. - """ - pass - - def __deepcopy__(self, memo): - """ - Create a deep copy of the SqliteWorkflowStorage instance, handling unpickleable attributes. - - Args: - memo (dict): A dictionary of objects already copied during the current copying pass. - - Returns: - SqliteWorkflowStorage: A deep-copied instance of SqliteWorkflowStorage. 
- """ - from copy import deepcopy - - # Create a new instance without calling __init__ - cls = self.__class__ - copied_obj = cls.__new__(cls) - memo[id(self)] = copied_obj - - # Deep copy attributes - for k, v in self.__dict__.items(): - if k in {"metadata", "table", "inspector"}: - continue - # Reuse db_engine and Session without copying - elif k in {"db_engine", "Session"}: - setattr(copied_obj, k, v) - else: - setattr(copied_obj, k, deepcopy(v, memo)) - - # Recreate metadata and table for the copied instance - copied_obj.metadata = MetaData() - copied_obj.inspector = inspect(copied_obj.db_engine) - copied_obj.table = copied_obj.get_table() - - return copied_obj +from agno.storage.sqlite import SqliteStorage as SqliteWorkflowStorage # noqa: F401 diff --git a/libs/agno/agno/storage/yaml.py b/libs/agno/agno/storage/yaml.py new file mode 100644 index 0000000000..16074b9e1e --- /dev/null +++ b/libs/agno/agno/storage/yaml.py @@ -0,0 +1,141 @@ +import time +from dataclasses import asdict +from pathlib import Path +from typing import List, Literal, Optional, Union + +import yaml + +from agno.storage.base import Storage +from agno.storage.session import Session +from agno.storage.session.agent import AgentSession +from agno.storage.session.workflow import WorkflowSession +from agno.utils.log import logger + + +class YamlStorage(Storage): + def __init__(self, dir_path: Union[str, Path], mode: Optional[Literal["agent", "workflow"]] = "agent"): + super().__init__(mode) + self.dir_path = Path(dir_path) + self.dir_path.mkdir(parents=True, exist_ok=True) + + def serialize(self, data: dict) -> str: + return yaml.dump(data, default_flow_style=False) + + def deserialize(self, data: str) -> dict: + return yaml.safe_load(data) + + def create(self) -> None: + """Create the storage if it doesn't exist.""" + if not self.dir_path.exists(): + self.dir_path.mkdir(parents=True, exist_ok=True) + + def read(self, session_id: str, user_id: Optional[str] = None) -> Optional[Session]: + """Read a Session from storage.""" + try: + with open(self.dir_path / f"{session_id}.yaml", "r", encoding="utf-8") as f: + data = self.deserialize(f.read()) + if user_id and data["user_id"] != user_id: + return None + if self.mode == "agent": + return AgentSession.from_dict(data) + elif self.mode == "workflow": + return WorkflowSession.from_dict(data) + except FileNotFoundError: + return None + + def get_all_session_ids(self, user_id: Optional[str] = None, entity_id: Optional[str] = None) -> List[str]: + """Get all session IDs, optionally filtered by user_id and/or entity_id.""" + session_ids = [] + for file in self.dir_path.glob("*.yaml"): + with open(file, "r", encoding="utf-8") as f: + data = self.deserialize(f.read()) + if user_id or entity_id: + if user_id and entity_id: + if self.mode == "agent" and data["agent_id"] == entity_id and data["user_id"] == user_id: + session_ids.append(data["session_id"]) + elif ( + self.mode == "workflow" and data["workflow_id"] == entity_id and data["user_id"] == user_id + ): + session_ids.append(data["session_id"]) + elif user_id and data["user_id"] == user_id: + session_ids.append(data["session_id"]) + elif entity_id: + if self.mode == "agent" and data["agent_id"] == entity_id: + session_ids.append(data["session_id"]) + elif self.mode == "workflow" and data["workflow_id"] == entity_id: + session_ids.append(data["session_id"]) + else: + # No filters applied, add all session_ids + session_ids.append(data["session_id"]) + return session_ids + + def get_all_sessions(self, user_id: Optional[str] = 
None, entity_id: Optional[str] = None) -> List[Session]:
+        """Get all sessions, optionally filtered by user_id and/or entity_id."""
+        sessions: List[Session] = []
+        for file in self.dir_path.glob("*.yaml"):
+            with open(file, "r", encoding="utf-8") as f:
+                data = self.deserialize(f.read())
+                if user_id or entity_id:
+                    _session: Optional[Session] = None
+                    if user_id and entity_id:
+                        if self.mode == "agent" and data["agent_id"] == entity_id and data["user_id"] == user_id:
+                            _session = AgentSession.from_dict(data)
+                        elif (
+                            self.mode == "workflow" and data["workflow_id"] == entity_id and data["user_id"] == user_id
+                        ):
+                            _session = WorkflowSession.from_dict(data)
+                    elif user_id and data["user_id"] == user_id:
+                        if self.mode == "agent":
+                            _session = AgentSession.from_dict(data)
+                        elif self.mode == "workflow":
+                            _session = WorkflowSession.from_dict(data)
+                    elif entity_id:
+                        if self.mode == "agent" and data["agent_id"] == entity_id:
+                            _session = AgentSession.from_dict(data)
+                        elif self.mode == "workflow" and data["workflow_id"] == entity_id:
+                            _session = WorkflowSession.from_dict(data)
+
+                    if _session:
+                        sessions.append(_session)
+                else:
+                    # No filters applied, add all sessions
+                    if self.mode == "agent":
+                        _session = AgentSession.from_dict(data)
+                    elif self.mode == "workflow":
+                        _session = WorkflowSession.from_dict(data)
+                    if _session:
+                        sessions.append(_session)
+        return sessions
+
+    def upsert(self, session: Session) -> Optional[Session]:
+        """Insert or update a Session in storage."""
+        try:
+            data = asdict(session)
+            data["updated_at"] = int(time.time())
+            # Set created_at on the first write; asdict() always includes the key, so also check for None
+            if data.get("created_at") is None:
+                data["created_at"] = data["updated_at"]
+
+            with open(self.dir_path / f"{session.session_id}.yaml", "w", encoding="utf-8") as f:
+                f.write(self.serialize(data))
+            return session
+        except Exception as e:
+            logger.error(f"Error upserting session: {e}")
+            return None
+
+    def delete_session(self, session_id: Optional[str] = None):
+        """Delete a session from storage."""
+        if session_id is None:
+            return
+        try:
+            (self.dir_path / f"{session_id}.yaml").unlink(missing_ok=True)
+        except Exception as e:
+            logger.error(f"Error deleting session: {e}")
+
+    def drop(self) -> None:
+        """Drop all sessions from storage."""
+        for file in self.dir_path.glob("*.yaml"):
+            file.unlink()
+
+    def upgrade_schema(self) -> None:
+        """Upgrade the schema of the storage."""
+        pass
diff --git a/libs/agno/agno/utils/certs.py b/libs/agno/agno/utils/certs.py
new file mode 100644
index 0000000000..6cd753a69b
--- /dev/null
+++ b/libs/agno/agno/utils/certs.py
@@ -0,0 +1,27 @@
+from pathlib import Path
+
+import requests
+
+
+def download_cert(cert_url: str, filename: str = "cert.pem"):
+    """
+    Downloads a CA certificate bundle if it doesn't exist locally.
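+
+    Args:
+        cert_url: URL of the CA certificate bundle to download.
+        filename: Name of the local file to store the bundle in. Defaults to "cert.pem".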
+ + Returns: + str: Path to the certificate file + """ + cert_dir = Path("./certs") + cert_path = cert_dir / filename + + # Create directory if it doesn't exist + cert_dir.mkdir(parents=True, exist_ok=True) + + # Download the certificate if it doesn't exist + if not cert_path.exists(): + response = requests.get(cert_url) + response.raise_for_status() + + with open(cert_path, "wb") as f: + f.write(response.content) + + return str(cert_path.absolute()) diff --git a/libs/agno/agno/workflow/__init__.py b/libs/agno/agno/workflow/__init__.py index b14b276c24..81482a2341 100644 --- a/libs/agno/agno/workflow/__init__.py +++ b/libs/agno/agno/workflow/__init__.py @@ -1 +1 @@ -from agno.workflow.workflow import RunEvent, RunResponse, Workflow, WorkflowSession, WorkflowStorage +from agno.workflow.workflow import RunEvent, RunResponse, Workflow, WorkflowSession # type: ignore diff --git a/libs/agno/agno/workflow/workflow.py b/libs/agno/agno/workflow/workflow.py index bd84f22876..6e2928a96e 100644 --- a/libs/agno/agno/workflow/workflow.py +++ b/libs/agno/agno/workflow/workflow.py @@ -14,8 +14,8 @@ from agno.media import AudioArtifact, ImageArtifact, VideoArtifact from agno.memory.workflow import WorkflowMemory, WorkflowRun from agno.run.response import RunEvent, RunResponse # noqa: F401 -from agno.storage.workflow.base import WorkflowStorage -from agno.storage.workflow.session import WorkflowSession +from agno.storage.base import Storage +from agno.storage.session.workflow import WorkflowSession from agno.utils.common import nested_model_dump from agno.utils.log import logger, set_log_level_to_debug, set_log_level_to_info from agno.utils.merge_dict import merge_dictionaries @@ -47,7 +47,7 @@ class Workflow: memory: Optional[WorkflowMemory] = None # --- Workflow Storage --- - storage: Optional[WorkflowStorage] = None + storage: Optional[Storage] = None # Extra data stored with this workflow extra_data: Optional[Dict[str, Any]] = None @@ -82,7 +82,7 @@ def __init__( session_name: Optional[str] = None, session_state: Optional[Dict[str, Any]] = None, memory: Optional[WorkflowMemory] = None, - storage: Optional[WorkflowStorage] = None, + storage: Optional[Storage] = None, extra_data: Optional[Dict[str, Any]] = None, debug_mode: bool = False, monitoring: bool = False, @@ -139,7 +139,8 @@ def run(self, **kwargs: Any): def run_workflow(self, **kwargs: Any): """Run the Workflow""" - # Set debug, workflow_id, session_id, initialize memory + # Set mode, debug, workflow_id, session_id, initialize memory + self.set_storage_mode() self.set_debug() self.set_workflow_id() self.set_session_id() @@ -218,6 +219,10 @@ def result_generator(): logger.warning(f"Workflow.run() should only return RunResponse objects, got: {type(result)}") return None + def set_storage_mode(self): + if self.storage is not None: + self.storage.mode = "workflow" + def set_workflow_id(self) -> str: if self.workflow_id is None: self.workflow_id = str(uuid4()) @@ -441,7 +446,7 @@ def read_from_storage(self) -> Optional[WorkflowSession]: Optional[WorkflowSession]: The loaded WorkflowSession or None if not found. 
""" if self.storage is not None and self.session_id is not None: - self.workflow_session = self.storage.read(session_id=self.session_id) + self.workflow_session = cast(WorkflowSession, self.storage.read(session_id=self.session_id)) if self.workflow_session is not None: self.load_workflow_session(session=self.workflow_session) return self.workflow_session @@ -453,7 +458,7 @@ def write_to_storage(self) -> Optional[WorkflowSession]: Optional[WorkflowSession]: The saved WorkflowSession or None if not saved. """ if self.storage is not None: - self.workflow_session = self.storage.upsert(session=self.get_workflow_session()) + self.workflow_session = cast(WorkflowSession, self.storage.upsert(session=self.get_workflow_session())) return self.workflow_session def load_session(self, force: bool = False) -> Optional[str]: @@ -570,7 +575,7 @@ def _deep_copy_field(self, field_name: str, field_value: Any) -> Any: return field_value.deep_copy() # For compound types, attempt a deep copy - if isinstance(field_value, (list, dict, set, WorkflowStorage)): + if isinstance(field_value, (list, dict, set, Storage)): try: return deepcopy(field_value) except Exception as e: diff --git a/libs/agno/pyproject.toml b/libs/agno/pyproject.toml index ad823eef15..bf83ea46a7 100644 --- a/libs/agno/pyproject.toml +++ b/libs/agno/pyproject.toml @@ -72,6 +72,7 @@ browserbase = ["browserbase", "playwright"] # Dependencies for Storage sql = ["sqlalchemy"] postgres = ["psycopg-binary", "psycopg"] +sqlite = ["sqlalchemy"] # Dependencies for Vector databases pgvector = ["pgvector"] @@ -132,6 +133,7 @@ tools = [ storage = [ "agno[sql]", "agno[postgres]", + "agno[sqlite]", ] # All vector databases diff --git a/libs/agno/tests/integration/models/anthropic/test_basic.py b/libs/agno/tests/integration/models/anthropic/test_basic.py index 5386974773..40f5b0e291 100644 --- a/libs/agno/tests/integration/models/anthropic/test_basic.py +++ b/libs/agno/tests/integration/models/anthropic/test_basic.py @@ -3,7 +3,7 @@ from agno.agent import Agent, RunResponse # noqa from agno.models.anthropic import Claude -from agno.storage.agent.sqlite import SqliteAgentStorage +from agno.storage.sqlite import SqliteStorage def _assert_metrics(response: RunResponse): @@ -120,7 +120,7 @@ class MovieScript(BaseModel): def test_history(): agent = Agent( model=Claude(id="claude-3-5-haiku-20241022"), - storage=SqliteAgentStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), + storage=SqliteStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), add_history_to_messages=True, telemetry=False, monitoring=False, diff --git a/libs/agno/tests/integration/models/aws/bedrock/test_basic.py b/libs/agno/tests/integration/models/aws/bedrock/test_basic.py index 8a810da826..4ae9b5b47b 100644 --- a/libs/agno/tests/integration/models/aws/bedrock/test_basic.py +++ b/libs/agno/tests/integration/models/aws/bedrock/test_basic.py @@ -2,7 +2,7 @@ from agno.agent import Agent, RunResponse # noqa from agno.models.aws import AwsBedrock -from agno.storage.agent.sqlite import SqliteAgentStorage +from agno.storage.sqlite import SqliteStorage def _assert_metrics(response: RunResponse): @@ -101,7 +101,7 @@ class MovieScript(BaseModel): def test_history(): agent = Agent( model=AwsBedrock(id="anthropic.claude-3-sonnet-20240229-v1:0"), - storage=SqliteAgentStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), + storage=SqliteStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), add_history_to_messages=True, telemetry=False, 
monitoring=False, diff --git a/libs/agno/tests/integration/models/aws/claude/test_basic.py b/libs/agno/tests/integration/models/aws/claude/test_basic.py index 3f5a782488..2b8d25d62d 100644 --- a/libs/agno/tests/integration/models/aws/claude/test_basic.py +++ b/libs/agno/tests/integration/models/aws/claude/test_basic.py @@ -3,7 +3,7 @@ from agno.agent import Agent, RunResponse # noqa from agno.models.aws import Claude -from agno.storage.agent.sqlite import SqliteAgentStorage +from agno.storage.sqlite import SqliteStorage def _assert_metrics(response: RunResponse): @@ -129,7 +129,7 @@ class MovieScript(BaseModel): def test_history(): agent = Agent( model=Claude(id="anthropic.claude-3-sonnet-20240229-v1:0"), - storage=SqliteAgentStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), + storage=SqliteStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), add_history_to_messages=True, telemetry=False, monitoring=False, diff --git a/libs/agno/tests/integration/models/azure/ai_foundry/test_basic.py b/libs/agno/tests/integration/models/azure/ai_foundry/test_basic.py index 05b9e8c752..142c2b8711 100644 --- a/libs/agno/tests/integration/models/azure/ai_foundry/test_basic.py +++ b/libs/agno/tests/integration/models/azure/ai_foundry/test_basic.py @@ -3,7 +3,7 @@ from agno.agent import Agent, RunResponse from agno.models.azure import AzureAIFoundry -from agno.storage.agent.sqlite import SqliteAgentStorage +from agno.storage.sqlite import SqliteStorage def _assert_metrics(response: RunResponse): @@ -129,7 +129,7 @@ class MovieScript(BaseModel): def test_history(): agent = Agent( model=AzureAIFoundry(id="Phi-4"), - storage=SqliteAgentStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), + storage=SqliteStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), add_history_to_messages=True, telemetry=False, monitoring=False, diff --git a/libs/agno/tests/integration/models/azure/openai/test_basic.py b/libs/agno/tests/integration/models/azure/openai/test_basic.py index 1ce145fb98..19a6d09ada 100644 --- a/libs/agno/tests/integration/models/azure/openai/test_basic.py +++ b/libs/agno/tests/integration/models/azure/openai/test_basic.py @@ -3,7 +3,7 @@ from agno.agent import Agent, RunResponse # noqa from agno.models.azure import AzureOpenAI -from agno.storage.agent.sqlite import SqliteAgentStorage +from agno.storage.sqlite import SqliteStorage def _assert_metrics(response: RunResponse): @@ -152,7 +152,7 @@ class MovieScript(BaseModel): def test_history(): agent = Agent( model=AzureOpenAI(id="gpt-4o-mini"), - storage=SqliteAgentStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), + storage=SqliteStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), add_history_to_messages=True, telemetry=False, monitoring=False, diff --git a/libs/agno/tests/integration/models/cohere/test_basic.py b/libs/agno/tests/integration/models/cohere/test_basic.py index ecfa8deae0..f2451bf78c 100644 --- a/libs/agno/tests/integration/models/cohere/test_basic.py +++ b/libs/agno/tests/integration/models/cohere/test_basic.py @@ -3,7 +3,7 @@ from agno.agent import Agent, RunResponse # noqa from agno.models.cohere import Cohere -from agno.storage.agent.sqlite import SqliteAgentStorage +from agno.storage.sqlite import SqliteStorage def _assert_metrics(response: RunResponse): @@ -118,7 +118,7 @@ class MovieScript(BaseModel): def test_history(): agent = Agent( model=Cohere(id="command"), - storage=SqliteAgentStorage(table_name="agent_sessions", 
db_file="tmp/agent_storage.db"), + storage=SqliteStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), add_history_to_messages=True, telemetry=False, monitoring=False, diff --git a/libs/agno/tests/integration/models/deepinfra/__init__.py b/libs/agno/tests/integration/models/deepinfra/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/libs/agno/tests/integration/models/deepinfra/test_basic.py b/libs/agno/tests/integration/models/deepinfra/test_basic.py index eb969e4042..f262c38f49 100644 --- a/libs/agno/tests/integration/models/deepinfra/test_basic.py +++ b/libs/agno/tests/integration/models/deepinfra/test_basic.py @@ -3,7 +3,7 @@ from agno.agent import Agent, RunResponse from agno.models.deepinfra import DeepInfra -from agno.storage.agent.sqlite import SqliteAgentStorage +from agno.storage.sqlite import SqliteStorage def _assert_metrics(response: RunResponse): @@ -125,7 +125,7 @@ class MovieScript(BaseModel): def test_history(): agent = Agent( model=DeepInfra(id="meta-llama/Llama-2-70b-chat-hf"), - storage=SqliteAgentStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), + storage=SqliteStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), add_history_to_messages=True, telemetry=False, monitoring=False, diff --git a/libs/agno/tests/integration/models/deepseek/test_basic.py b/libs/agno/tests/integration/models/deepseek/test_basic.py index 8255452d95..234317265a 100644 --- a/libs/agno/tests/integration/models/deepseek/test_basic.py +++ b/libs/agno/tests/integration/models/deepseek/test_basic.py @@ -3,7 +3,7 @@ from agno.agent import Agent, RunResponse # noqa from agno.models.deepseek import DeepSeek -from agno.storage.agent.sqlite import SqliteAgentStorage +from agno.storage.sqlite import SqliteStorage def _assert_metrics(response: RunResponse): @@ -118,7 +118,7 @@ class MovieScript(BaseModel): def test_history(): agent = Agent( model=DeepSeek(id="deepseek-chat"), - storage=SqliteAgentStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), + storage=SqliteStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), add_history_to_messages=True, telemetry=False, monitoring=False, diff --git a/libs/agno/tests/integration/models/fireworks/test_basic.py b/libs/agno/tests/integration/models/fireworks/test_basic.py index eb85d92c5b..6e1a2d9de8 100644 --- a/libs/agno/tests/integration/models/fireworks/test_basic.py +++ b/libs/agno/tests/integration/models/fireworks/test_basic.py @@ -3,7 +3,7 @@ from agno.agent import Agent, RunResponse from agno.models.fireworks import Fireworks -from agno.storage.agent.sqlite import SqliteAgentStorage +from agno.storage.sqlite import SqliteStorage def _assert_metrics(response: RunResponse): @@ -143,7 +143,7 @@ class MovieScript(BaseModel): def test_history(): agent = Agent( model=Fireworks(id="accounts/fireworks/models/llama-v3p1-8b-instruct"), - storage=SqliteAgentStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), + storage=SqliteStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), add_history_to_messages=True, telemetry=False, monitoring=False, diff --git a/libs/agno/tests/integration/models/google/test_basic.py b/libs/agno/tests/integration/models/google/test_basic.py index 58282de262..834d862c41 100644 --- a/libs/agno/tests/integration/models/google/test_basic.py +++ b/libs/agno/tests/integration/models/google/test_basic.py @@ -10,7 +10,7 @@ from agno.memory.manager import MemoryManager from agno.memory.summarizer import 
MemorySummarizer from agno.models.google import Gemini -from agno.storage.agent.sqlite import SqliteAgentStorage +from agno.storage.sqlite import SqliteStorage from agno.tools.duckduckgo import DuckDuckGoTools @@ -145,7 +145,7 @@ def test_persistent_memory(): instructions=[ "You can search the internet with DuckDuckGo.", ], - storage=SqliteAgentStorage(table_name="chat_agent", db_file="tmp/agent_storage.db"), + storage=SqliteStorage(table_name="chat_agent", db_file="tmp/agent_storage.db"), # Adds the current date and time to the instructions add_datetime_to_instructions=True, # Adds the history of the conversation to the messages @@ -218,7 +218,7 @@ class MovieScript(BaseModel): def test_history(): agent = Agent( model=Gemini(id="gemini-1.5-flash"), - storage=SqliteAgentStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), + storage=SqliteStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), add_history_to_messages=True, telemetry=False, monitoring=False, diff --git a/libs/agno/tests/integration/models/groq/test_basic.py b/libs/agno/tests/integration/models/groq/test_basic.py index cc676ed771..886022836f 100644 --- a/libs/agno/tests/integration/models/groq/test_basic.py +++ b/libs/agno/tests/integration/models/groq/test_basic.py @@ -3,7 +3,7 @@ from agno.agent import Agent, RunResponse # noqa from agno.models.groq import Groq -from agno.storage.agent.sqlite import SqliteAgentStorage +from agno.storage.sqlite import SqliteStorage def _assert_metrics(response: RunResponse): @@ -129,7 +129,7 @@ class MovieScript(BaseModel): def test_history(): agent = Agent( model=Groq(id="mixtral-8x7b-32768"), - storage=SqliteAgentStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), + storage=SqliteStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), add_history_to_messages=True, telemetry=False, monitoring=False, diff --git a/libs/agno/tests/integration/models/huggingface/test_basic.py b/libs/agno/tests/integration/models/huggingface/test_basic.py index 68d95f3525..970477e19f 100644 --- a/libs/agno/tests/integration/models/huggingface/test_basic.py +++ b/libs/agno/tests/integration/models/huggingface/test_basic.py @@ -9,7 +9,7 @@ from agno.memory.manager import MemoryManager from agno.memory.summarizer import MemorySummarizer from agno.models.huggingface import HuggingFace -from agno.storage.agent.sqlite import SqliteAgentStorage +from agno.storage.sqlite import SqliteStorage from agno.tools.duckduckgo import DuckDuckGoTools @@ -145,7 +145,7 @@ class MovieScript(BaseModel): def test_history(): agent = Agent( model=HuggingFace(id="mistralai/Mistral-7B-Instruct-v0.2"), - storage=SqliteAgentStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), + storage=SqliteStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), add_history_to_messages=True, telemetry=False, monitoring=False, @@ -171,7 +171,7 @@ def test_persistent_memory(): instructions=[ "You can search the internet with DuckDuckGo.", ], - storage=SqliteAgentStorage(table_name="chat_agent", db_file="tmp/agent_storage.db"), + storage=SqliteStorage(table_name="chat_agent", db_file="tmp/agent_storage.db"), add_datetime_to_instructions=True, add_history_to_messages=True, num_history_responses=15, diff --git a/libs/agno/tests/integration/models/ibm/watsonx/test_basic.py b/libs/agno/tests/integration/models/ibm/watsonx/test_basic.py index 40a2f452b3..bba92b821f 100644 --- a/libs/agno/tests/integration/models/ibm/watsonx/test_basic.py +++ 
b/libs/agno/tests/integration/models/ibm/watsonx/test_basic.py @@ -3,7 +3,7 @@ from agno.agent import Agent, RunResponse # noqa from agno.models.ibm import WatsonX -from agno.storage.agent.sqlite import SqliteAgentStorage +from agno.storage.sqlite import SqliteStorage def _assert_metrics(response: RunResponse): @@ -124,7 +124,7 @@ class MovieScript(BaseModel): def test_history(): agent = Agent( model=WatsonX(id="ibm/granite-20b-code-instruct"), - storage=SqliteAgentStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), + storage=SqliteStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), add_history_to_messages=True, telemetry=False, monitoring=False, diff --git a/libs/agno/tests/integration/models/mistral/test_basic.py b/libs/agno/tests/integration/models/mistral/test_basic.py index f737916667..65f0926a2d 100644 --- a/libs/agno/tests/integration/models/mistral/test_basic.py +++ b/libs/agno/tests/integration/models/mistral/test_basic.py @@ -4,7 +4,7 @@ from agno.agent import Agent, RunResponse # noqa from agno.models.groq.groq import Groq from agno.models.mistral import MistralChat -from agno.storage.agent.sqlite import SqliteAgentStorage +from agno.storage.sqlite import SqliteStorage def _assert_metrics(response: RunResponse): @@ -149,7 +149,7 @@ class MovieScript(BaseModel): def test_history(): agent = Agent( model=MistralChat(id="mistral-small"), - storage=SqliteAgentStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), + storage=SqliteStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), add_history_to_messages=True, telemetry=False, monitoring=False, diff --git a/libs/agno/tests/integration/models/nvidia/test_basic.py b/libs/agno/tests/integration/models/nvidia/test_basic.py index 76cef466de..2f18d38b9e 100644 --- a/libs/agno/tests/integration/models/nvidia/test_basic.py +++ b/libs/agno/tests/integration/models/nvidia/test_basic.py @@ -3,7 +3,7 @@ from agno.agent import Agent, RunResponse from agno.models.nvidia import Nvidia -from agno.storage.agent.sqlite import SqliteAgentStorage +from agno.storage.sqlite import SqliteStorage def _assert_metrics(response: RunResponse): @@ -124,7 +124,7 @@ class MovieScript(BaseModel): def test_history(): agent = Agent( model=Nvidia(id="meta/llama-3.3-70b-instruct"), - storage=SqliteAgentStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), + storage=SqliteStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), add_history_to_messages=True, telemetry=False, monitoring=False, diff --git a/libs/agno/tests/integration/models/ollama/test_basic.py b/libs/agno/tests/integration/models/ollama/test_basic.py index ccdc4516aa..ea7d83f2d6 100644 --- a/libs/agno/tests/integration/models/ollama/test_basic.py +++ b/libs/agno/tests/integration/models/ollama/test_basic.py @@ -3,7 +3,7 @@ from agno.agent import Agent, RunResponse from agno.models.ollama import Ollama -from agno.storage.agent.sqlite import SqliteAgentStorage +from agno.storage.sqlite import SqliteStorage def _assert_metrics(response: RunResponse): @@ -124,7 +124,7 @@ class MovieScript(BaseModel): def test_history(): agent = Agent( model=Ollama(id="llama3.2:latest"), - storage=SqliteAgentStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), + storage=SqliteStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), add_history_to_messages=True, telemetry=False, monitoring=False, diff --git a/libs/agno/tests/integration/models/ollama_tools/test_basic.py 
b/libs/agno/tests/integration/models/ollama_tools/test_basic.py
index 21538112ca..b16932b551 100644
--- a/libs/agno/tests/integration/models/ollama_tools/test_basic.py
+++ b/libs/agno/tests/integration/models/ollama_tools/test_basic.py
@@ -3,7 +3,7 @@
 from agno.agent import Agent, RunResponse
 from agno.models.ollama import OllamaTools
-from agno.storage.agent.sqlite import SqliteAgentStorage
+from agno.storage.sqlite import SqliteStorage
 
 
 def _assert_metrics(response: RunResponse):
@@ -110,7 +110,7 @@ class MovieScript(BaseModel):
 def test_history():
     agent = Agent(
         model=OllamaTools(id="mistral"),
-        storage=SqliteAgentStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"),
+        storage=SqliteStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"),
         add_history_to_messages=True,
         telemetry=False,
         monitoring=False,
diff --git a/libs/agno/tests/integration/models/openai/chat/__init__.py b/libs/agno/tests/integration/models/openai/chat/__init__.py
new file mode 100644
index 0000000000..aabeb362b8
--- /dev/null
+++ b/libs/agno/tests/integration/models/openai/chat/__init__.py
@@ -0,0 +1 @@
+"""Integration tests for OpenAI Chat API."""
diff --git a/libs/agno/tests/integration/models/openai/chat/test_basic.py b/libs/agno/tests/integration/models/openai/chat/test_basic.py
index a7210886be..75c36374ae 100644
--- a/libs/agno/tests/integration/models/openai/chat/test_basic.py
+++ b/libs/agno/tests/integration/models/openai/chat/test_basic.py
@@ -9,7 +9,7 @@
 from agno.memory.manager import MemoryManager
 from agno.memory.summarizer import MemorySummarizer
 from agno.models.openai import OpenAIChat
-from agno.storage.agent.sqlite import SqliteAgentStorage
+from agno.storage.sqlite import SqliteStorage
 from agno.tools.duckduckgo import DuckDuckGoTools
 
@@ -162,7 +162,7 @@ class MovieScript(BaseModel):
 def test_history():
     agent = Agent(
         model=OpenAIChat(id="gpt-4o-mini"),
-        storage=SqliteAgentStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"),
+        storage=SqliteStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"),
         add_history_to_messages=True,
         telemetry=False,
         monitoring=False,
@@ -188,7 +188,7 @@ def test_persistent_memory():
         instructions=[
             "You can search the internet with DuckDuckGo.",
         ],
-        storage=SqliteAgentStorage(table_name="chat_agent", db_file="tmp/agent_storage.db"),
+        storage=SqliteStorage(table_name="chat_agent", db_file="tmp/agent_storage.db"),
         # Adds the current date and time to the instructions
         add_datetime_to_instructions=True,
         # Adds the history of the conversation to the messages
diff --git a/libs/agno/tests/integration/models/openrouter/test_basic.py b/libs/agno/tests/integration/models/openrouter/test_basic.py
index 983bdcfc00..cea49f5feb 100644
--- a/libs/agno/tests/integration/models/openrouter/test_basic.py
+++ b/libs/agno/tests/integration/models/openrouter/test_basic.py
@@ -3,7 +3,7 @@
 from agno.agent import Agent, RunResponse
 from agno.models.openrouter import OpenRouter
-from agno.storage.agent.sqlite import SqliteAgentStorage
+from agno.storage.sqlite import SqliteStorage
 
 
 def _assert_metrics(response: RunResponse):
@@ -123,7 +123,7 @@ class MovieScript(BaseModel):
 def test_history():
     agent = Agent(
         model=OpenRouter(id="anthropic/claude-3-sonnet"),
-        storage=SqliteAgentStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"),
+        storage=SqliteStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"),
         add_history_to_messages=True,
         telemetry=False,
         monitoring=False,
diff --git
a/libs/agno/tests/integration/models/perplexity/test_basic.py b/libs/agno/tests/integration/models/perplexity/test_basic.py index 83aa6c1835..03efcee8f2 100644 --- a/libs/agno/tests/integration/models/perplexity/test_basic.py +++ b/libs/agno/tests/integration/models/perplexity/test_basic.py @@ -3,7 +3,7 @@ from agno.agent import Agent, RunResponse from agno.models.perplexity import Perplexity -from agno.storage.agent.sqlite import SqliteAgentStorage +from agno.storage.sqlite import SqliteStorage def _assert_metrics(response: RunResponse): @@ -124,7 +124,7 @@ class MovieScript(BaseModel): def test_history(): agent = Agent( model=Perplexity(id="sonar"), - storage=SqliteAgentStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), + storage=SqliteStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), add_history_to_messages=True, telemetry=False, monitoring=False, diff --git a/libs/agno/tests/integration/models/sambanova/test_basic.py b/libs/agno/tests/integration/models/sambanova/test_basic.py index 6b965246a0..072cef3c25 100644 --- a/libs/agno/tests/integration/models/sambanova/test_basic.py +++ b/libs/agno/tests/integration/models/sambanova/test_basic.py @@ -3,7 +3,7 @@ from agno.agent import Agent, RunResponse from agno.models.sambanova import Sambanova -from agno.storage.agent.sqlite import SqliteAgentStorage +from agno.storage.sqlite import SqliteStorage def _assert_metrics(response: RunResponse): @@ -123,7 +123,7 @@ class MovieScript(BaseModel): def test_history(): agent = Agent( model=Sambanova(id="Meta-Llama-3.1-8B-Instruct"), - storage=SqliteAgentStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), + storage=SqliteStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), add_history_to_messages=True, telemetry=False, monitoring=False, diff --git a/libs/agno/tests/integration/models/together/test_basic.py b/libs/agno/tests/integration/models/together/test_basic.py index 530cf5f98e..1c73bbf883 100644 --- a/libs/agno/tests/integration/models/together/test_basic.py +++ b/libs/agno/tests/integration/models/together/test_basic.py @@ -3,7 +3,7 @@ from agno.agent import Agent, RunResponse from agno.models.together import Together -from agno.storage.agent.sqlite import SqliteAgentStorage +from agno.storage.sqlite import SqliteStorage def _assert_metrics(response: RunResponse): @@ -143,7 +143,7 @@ class MovieScript(BaseModel): def test_history(): agent = Agent( model=Together(id="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"), - storage=SqliteAgentStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), + storage=SqliteStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), add_history_to_messages=True, telemetry=False, monitoring=False, diff --git a/libs/agno/tests/integration/models/xai/test_basic.py b/libs/agno/tests/integration/models/xai/test_basic.py index 9e798fb7fe..492082d284 100644 --- a/libs/agno/tests/integration/models/xai/test_basic.py +++ b/libs/agno/tests/integration/models/xai/test_basic.py @@ -3,7 +3,7 @@ from agno.agent import Agent, RunResponse from agno.models.xai import xAI -from agno.storage.agent.sqlite import SqliteAgentStorage +from agno.storage.sqlite import SqliteStorage def _assert_metrics(response: RunResponse): @@ -122,7 +122,7 @@ class MovieScript(BaseModel): def test_history(): agent = Agent( model=xAI(id="grok-beta"), - storage=SqliteAgentStorage(table_name="agent_sessions", db_file="tmp/agent_storage.db"), + storage=SqliteStorage(table_name="agent_sessions", 
db_file="tmp/agent_storage.db"), add_history_to_messages=True, telemetry=False, monitoring=False, diff --git a/libs/agno/tests/integration/storage/__init__.py b/libs/agno/tests/integration/storage/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/libs/agno/tests/integration/storage/test_json_storage_agent.py b/libs/agno/tests/integration/storage/test_json_storage_agent.py new file mode 100644 index 0000000000..dd388c7c5b --- /dev/null +++ b/libs/agno/tests/integration/storage/test_json_storage_agent.py @@ -0,0 +1,151 @@ +import shutil + +import pytest + +from agno.agent import Agent +from agno.storage.json import JsonStorage +from agno.storage.session.agent import AgentSession + + +@pytest.fixture +def temp_storage_path(tmp_path): + """Create a temporary directory for storage that's cleaned up after tests.""" + storage_dir = tmp_path / "test_storage" + storage_dir.mkdir() + yield storage_dir + shutil.rmtree(storage_dir) + + +@pytest.fixture +def agent_storage(temp_storage_path): + """Create a JsonStorage instance for agent sessions.""" + return JsonStorage(dir_path=temp_storage_path, mode="agent") + + +@pytest.fixture +def workflow_storage(temp_storage_path): + """Create a JsonStorage instance for workflow sessions.""" + return JsonStorage(dir_path=temp_storage_path / "workflows", mode="workflow") + + +@pytest.fixture +def agent_with_storage(agent_storage): + """Create an agent with the test storage.""" + return Agent(storage=agent_storage, add_history_to_messages=True) + + +def test_storage_creation(temp_storage_path): + """Test that storage directory is created.""" + JsonStorage(dir_path=temp_storage_path) + assert temp_storage_path.exists() + assert temp_storage_path.is_dir() + + +def test_agent_session_storage(agent_with_storage, agent_storage): + """Test that agent sessions are properly stored.""" + # Run agent and get response + agent_with_storage.run("What is the capital of France?") + + # Get the session ID from the agent + session_id = agent_with_storage.session_id + + # Verify session was stored + stored_session = agent_storage.read(session_id) + assert stored_session is not None + assert isinstance(stored_session, AgentSession) + assert stored_session.session_id == session_id + + # Verify session contains the interaction + assert len(stored_session.memory["messages"]) > 0 + + +def test_multiple_interactions(agent_with_storage, agent_storage): + """Test that multiple interactions are properly stored in the same session.""" + # First interaction + agent_with_storage.run("What is the capital of France?") + session_id = agent_with_storage.session_id + + # Second interaction + agent_with_storage.run("What is its population?") + + # Verify both interactions are in the same session + stored_session = agent_storage.read(session_id) + assert stored_session is not None + assert len(stored_session.memory["messages"]) >= 4 # Should have at least 4 messages (2 questions + 2 responses) + + +def test_session_retrieval_by_user(agent_with_storage, agent_storage): + """Test retrieving sessions filtered by user ID.""" + # Create a session with a specific user ID + agent_with_storage.user_id = "test_user" + agent_with_storage.run("What is the capital of France?") + + # Get all sessions for the user + sessions = agent_storage.get_all_sessions(user_id="test_user") + assert len(sessions) == 1 + assert sessions[0].user_id == "test_user" + + # Verify no sessions for different user + other_sessions = agent_storage.get_all_sessions(user_id="other_user") + assert 
len(other_sessions) == 0 + + +def test_session_deletion(agent_with_storage, agent_storage): + """Test deleting a session.""" + # Create a session + agent_with_storage.run("What is the capital of France?") + session_id = agent_with_storage.session_id + + # Verify session exists + assert agent_storage.read(session_id) is not None + + # Delete session + agent_storage.delete_session(session_id) + + # Verify session was deleted + assert agent_storage.read(session_id) is None + + +def test_get_all_session_ids(agent_storage): + """Test retrieving all session IDs.""" + # Create multiple sessions with different user IDs and agent IDs + agent_1 = Agent(storage=agent_storage, user_id="user1", agent_id="agent1", add_history_to_messages=True) + agent_2 = Agent(storage=agent_storage, user_id="user1", agent_id="agent2", add_history_to_messages=True) + agent_3 = Agent(storage=agent_storage, user_id="user2", agent_id="agent3", add_history_to_messages=True) + + agent_1.run("Question 1") + agent_2.run("Question 2") + agent_3.run("Question 3") + + # Get all session IDs + all_sessions = agent_storage.get_all_session_ids() + assert len(all_sessions) == 3 + + # Filter by user ID + user1_sessions = agent_storage.get_all_session_ids(user_id="user1") + assert len(user1_sessions) == 2 + + # Filter by agent ID + agent1_sessions = agent_storage.get_all_session_ids(entity_id="agent1") + assert len(agent1_sessions) == 1 + + # Filter by both + filtered_sessions = agent_storage.get_all_session_ids(user_id="user1", entity_id="agent2") + assert len(filtered_sessions) == 1 + + +def test_drop_storage(agent_with_storage, agent_storage): + """Test dropping all sessions from storage.""" + # Create a few sessions + for i in range(3): + agent = Agent(storage=agent_storage, add_history_to_messages=True) + agent.run(f"Question {i}") + + # Verify sessions exist + assert len(agent_storage.get_all_session_ids()) == 3 + + # Drop all sessions + agent_storage.drop() + + # Verify no sessions remain + assert len(agent_storage.get_all_session_ids()) == 0 diff --git a/libs/agno/tests/integration/storage/test_json_storage_workflow.py b/libs/agno/tests/integration/storage/test_json_storage_workflow.py new file mode 100644 index 0000000000..48a3d0bf9e --- /dev/null +++ b/libs/agno/tests/integration/storage/test_json_storage_workflow.py @@ -0,0 +1,205 @@ +import shutil + +import pytest + +from agno.agent import Agent +from agno.run.response import RunResponse +from agno.storage.json import JsonStorage +from agno.storage.session.workflow import WorkflowSession +from agno.workflow import Workflow + + +@pytest.fixture +def temp_storage_path(tmp_path): + """Create a temporary directory for storage that's cleaned up after tests.""" + storage_dir = tmp_path / "test_storage" + storage_dir.mkdir() + yield storage_dir + shutil.rmtree(storage_dir) + + +@pytest.fixture +def workflow_storage(temp_storage_path): + """Create a JsonStorage instance for workflow sessions.""" + return JsonStorage(dir_path=temp_storage_path) + + +class SimpleWorkflow(Workflow): + """A simple workflow with a single agent for testing.""" + + description: str = "A simple workflow for testing storage" + + test_agent: Agent = Agent( + description="A test agent for the workflow", + ) + + def run(self, query: str) -> RunResponse: + """Run the workflow with a simple query.""" + response = self.test_agent.run(query) + return RunResponse(run_id=self.run_id, content=f"Workflow processed: {response.content}") + + +@pytest.fixture +def workflow_with_storage(workflow_storage): + """Create a 
workflow with the test storage.""" + return SimpleWorkflow(storage=workflow_storage, name="TestWorkflow") + + +def test_storage_creation(temp_storage_path): + """Test that storage directory is created.""" + JsonStorage(dir_path=temp_storage_path, mode="workflow") + assert temp_storage_path.exists() + assert temp_storage_path.is_dir() + + +def test_workflow_session_storage(workflow_with_storage, workflow_storage): + """Test that workflow sessions are properly stored.""" + # Run workflow and get response + workflow_with_storage.run(query="What is the capital of France?") + + assert workflow_with_storage.storage.mode == "workflow" + + # Get the session ID from the workflow + session_id = workflow_with_storage.session_id + + # Verify session was stored + stored_session = workflow_storage.read(session_id) + assert stored_session is not None + assert isinstance(stored_session, WorkflowSession) + assert stored_session.session_id == session_id + + # Verify workflow data was stored + assert stored_session.workflow_data is not None + assert stored_session.workflow_data.get("name") == "TestWorkflow" + + # Verify memory contains the run + assert stored_session.memory is not None + assert "runs" in stored_session.memory + assert len(stored_session.memory["runs"]) > 0 + + +def test_multiple_interactions(workflow_with_storage, workflow_storage): + """Test that multiple interactions are properly stored in the same session.""" + # First interaction + workflow_with_storage.run(query="What is the capital of France?") + + assert workflow_with_storage.storage.mode == "workflow" + session_id = workflow_with_storage.session_id + + # Second interaction + workflow_with_storage.run(query="What is its population?") + + # Verify both interactions are in the same session + stored_session = workflow_storage.read(session_id) + assert stored_session is not None + assert "runs" in stored_session.memory + assert len(stored_session.memory["runs"]) == 2 # Should have 2 runs + + +def test_session_retrieval_by_user(workflow_storage): + """Test retrieving sessions filtered by user ID.""" + # Create a session with a specific user ID + workflow = SimpleWorkflow(storage=workflow_storage, user_id="test_user", name="UserTestWorkflow") + workflow.run(query="What is the capital of France?") + + assert workflow.storage.mode == "workflow" + + # Get all sessions for the user + sessions = workflow_storage.get_all_sessions(user_id="test_user") + assert len(sessions) == 1 + assert sessions[0].user_id == "test_user" + + # Verify no sessions for different user + other_sessions = workflow_storage.get_all_sessions(user_id="other_user") + assert len(other_sessions) == 0 + + +def test_session_deletion(workflow_with_storage, workflow_storage): + """Test deleting a session.""" + # Create a session + workflow_with_storage.run(query="What is the capital of France?") + session_id = workflow_with_storage.session_id + + # Verify session exists + assert workflow_storage.read(session_id) is not None + + # Delete session + workflow_storage.delete_session(session_id) + + # Verify session was deleted + assert workflow_storage.read(session_id) is None + + +def test_get_all_session_ids(workflow_storage): + """Test retrieving all session IDs.""" + # Create multiple sessions with different user IDs and workflow IDs + workflow_1 = SimpleWorkflow(storage=workflow_storage, user_id="user1", workflow_id="workflow1", name="Workflow1") + workflow_2 = SimpleWorkflow(storage=workflow_storage, user_id="user1", workflow_id="workflow2", name="Workflow2") + workflow_3 = 
SimpleWorkflow(storage=workflow_storage, user_id="user2", workflow_id="workflow3", name="Workflow3") + + workflow_1.run(query="Question 1") + workflow_2.run(query="Question 2") + workflow_3.run(query="Question 3") + + # Get all session IDs + all_sessions = workflow_storage.get_all_session_ids() + assert len(all_sessions) == 3 + + # Filter by user ID + user1_sessions = workflow_storage.get_all_session_ids(user_id="user1") + assert len(user1_sessions) == 2 + + # Filter by workflow ID + workflow1_sessions = workflow_storage.get_all_session_ids(entity_id="workflow1") + assert len(workflow1_sessions) == 1 + + # Filter by both + filtered_sessions = workflow_storage.get_all_session_ids(user_id="user1", entity_id="workflow2") + assert len(filtered_sessions) == 1 + + +def test_drop_storage(workflow_storage): + """Test dropping all sessions from storage.""" + # Create a few sessions + for i in range(3): + workflow = SimpleWorkflow(storage=workflow_storage, name=f"Workflow{i}") + workflow.run(query=f"Question {i}") + + # Verify sessions exist + assert len(workflow_storage.get_all_session_ids()) == 3 + + # Drop all sessions + workflow_storage.drop() + + # Verify no sessions remain + assert len(workflow_storage.get_all_session_ids()) == 0 + + +def test_workflow_session_rename(workflow_with_storage, workflow_storage): + """Test renaming a workflow session.""" + # Create a session + workflow_with_storage.run(query="What is the capital of France?") + session_id = workflow_with_storage.session_id + + # Rename the session + workflow_with_storage.rename_session("My Renamed Session") + + # Verify session was renamed + stored_session = workflow_storage.read(session_id) + assert stored_session is not None + assert stored_session.session_data.get("session_name") == "My Renamed Session" + + +def test_workflow_rename(workflow_with_storage, workflow_storage): + """Test renaming a workflow.""" + # Create a session + workflow_with_storage.run(query="What is the capital of France?") + session_id = workflow_with_storage.session_id + + # Rename the workflow + workflow_with_storage.rename("My Renamed Workflow") + + # Verify workflow was renamed + stored_session = workflow_storage.read(session_id) + assert stored_session is not None + assert stored_session.workflow_data.get("name") == "My Renamed Workflow" diff --git a/libs/agno/tests/integration/storage/test_sqlite_storage_agent.py b/libs/agno/tests/integration/storage/test_sqlite_storage_agent.py new file mode 100644 index 0000000000..65210f69ec --- /dev/null +++ b/libs/agno/tests/integration/storage/test_sqlite_storage_agent.py @@ -0,0 +1,152 @@ +import os +import uuid + +import pytest + +from agno.agent import Agent +from agno.storage.session.agent import AgentSession +from agno.storage.sqlite import SqliteStorage + + +@pytest.fixture +def temp_db_path(tmp_path): + """Create a temporary database file path.""" + db_file = tmp_path / "test_agent_storage.db" + yield str(db_file) + # Clean up the file after tests + if os.path.exists(db_file): + os.remove(db_file) + + +@pytest.fixture +def agent_storage(temp_db_path): + """Create a SqliteStorage instance for agent sessions.""" + # Use a unique table name for each test run + table_name = f"agent_sessions_{uuid.uuid4().hex[:8]}" + storage = SqliteStorage(table_name=table_name, db_file=temp_db_path, mode="agent") + storage.create() + return storage + + +@pytest.fixture +def agent_with_storage(agent_storage): + """Create an agent with the test storage.""" + return Agent(storage=agent_storage, add_history_to_messages=True) + + 
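+# Usage sketch of what the fixtures above wire together (illustrative only: it
+# assumes the SqliteStorage constructor and Agent arguments exercised by these
+# tests, and "tmp/agents.db" is a hypothetical path, not the temp file used here):
+#
+#   storage = SqliteStorage(table_name="agent_sessions", db_file="tmp/agents.db", mode="agent")
+#   agent = Agent(storage=storage, add_history_to_messages=True)
+#   agent.run("What is the capital of France?")  # the run is persisted to SQLite
+#   session = storage.read(agent.session_id)     # reads back an AgentSession
+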
+def test_storage_creation(temp_db_path): + """Test that storage is created correctly.""" + storage = SqliteStorage(table_name="agent_sessions", db_file=temp_db_path) + storage.create() + assert os.path.exists(temp_db_path) + assert storage.table_exists() + + +def test_agent_session_storage(agent_with_storage, agent_storage): + """Test that agent sessions are properly stored.""" + # Run agent and get response + agent_with_storage.run("What is the capital of France?") + + # Get the session ID from the agent + session_id = agent_with_storage.session_id + + # Verify session was stored + stored_session = agent_storage.read(session_id) + assert stored_session is not None + assert isinstance(stored_session, AgentSession) + assert stored_session.session_id == session_id + + # Verify session contains the interaction + assert len(stored_session.memory["messages"]) > 0 + + +def test_multiple_interactions(agent_with_storage, agent_storage): + """Test that multiple interactions are properly stored in the same session.""" + # First interaction + agent_with_storage.run("What is the capital of France?") + session_id = agent_with_storage.session_id + + # Second interaction + agent_with_storage.run("What is its population?") + + # Verify both interactions are in the same session + stored_session = agent_storage.read(session_id) + assert stored_session is not None + assert len(stored_session.memory["messages"]) >= 4 # Should have at least 4 messages (2 questions + 2 responses) + + +def test_session_retrieval_by_user(agent_with_storage, agent_storage): + """Test retrieving sessions filtered by user ID.""" + # Create a session with a specific user ID + agent_with_storage.user_id = "test_user" + agent_with_storage.run("What is the capital of France?") + + # Get all sessions for the user + sessions = agent_storage.get_all_sessions(user_id="test_user") + assert len(sessions) == 1 + assert sessions[0].user_id == "test_user" + + # Verify no sessions for different user + other_sessions = agent_storage.get_all_sessions(user_id="other_user") + assert len(other_sessions) == 0 + + +def test_session_deletion(agent_with_storage, agent_storage): + """Test deleting a session.""" + # Create a session + agent_with_storage.run("What is the capital of France?") + session_id = agent_with_storage.session_id + + # Verify session exists + assert agent_storage.read(session_id) is not None + + # Delete session + agent_storage.delete_session(session_id) + + # Verify session was deleted + assert agent_storage.read(session_id) is None + + +def test_get_all_session_ids(agent_storage): + """Test retrieving all session IDs.""" + # Create multiple sessions with different user IDs and agent IDs + agent_1 = Agent(storage=agent_storage, user_id="user1", agent_id="agent1", add_history_to_messages=True) + agent_2 = Agent(storage=agent_storage, user_id="user1", agent_id="agent2", add_history_to_messages=True) + agent_3 = Agent(storage=agent_storage, user_id="user2", agent_id="agent3", add_history_to_messages=True) + + agent_1.run("Question 1") + agent_2.run("Question 2") + agent_3.run("Question 3") + + # Get all session IDs + all_sessions = agent_storage.get_all_session_ids() + assert len(all_sessions) == 3 + + # Filter by user ID + user1_sessions = agent_storage.get_all_session_ids(user_id="user1") + assert len(user1_sessions) == 2 + + # Filter by agent ID + agent1_sessions = agent_storage.get_all_session_ids(entity_id="agent1") + assert len(agent1_sessions) == 1 + + # Filter by both + filtered_sessions = 
agent_storage.get_all_session_ids(user_id="user1", entity_id="agent2") + assert len(filtered_sessions) == 1 + + +def test_drop_storage(agent_with_storage, agent_storage): + """Test dropping all sessions from storage.""" + # Create a few sessions + for i in range(3): + agent = Agent(storage=agent_storage, add_history_to_messages=True) + agent.run(f"Question {i}") + + # Verify sessions exist + assert len(agent_storage.get_all_session_ids()) == 3 + + # Drop all sessions + agent_storage.drop() + + # Verify no sessions remain + assert len(agent_storage.get_all_session_ids()) == 0 diff --git a/libs/agno/tests/integration/storage/test_sqlite_storage_workflow.py b/libs/agno/tests/integration/storage/test_sqlite_storage_workflow.py new file mode 100644 index 0000000000..0e0b1d661a --- /dev/null +++ b/libs/agno/tests/integration/storage/test_sqlite_storage_workflow.py @@ -0,0 +1,209 @@ +import os + +import pytest + +from agno.agent import Agent +from agno.run.response import RunResponse +from agno.storage.session.workflow import WorkflowSession +from agno.storage.sqlite import SqliteStorage +from agno.workflow import Workflow + + +@pytest.fixture +def temp_db_path(tmp_path): + """Create a temporary database file path.""" + db_file = tmp_path / "test_workflow_storage.db" + yield str(db_file) + # Clean up the file after tests + if os.path.exists(db_file): + os.remove(db_file) + + +@pytest.fixture +def workflow_storage(temp_db_path): + """Create a SqliteStorage instance for workflow sessions.""" + storage = SqliteStorage(table_name="workflow_sessions", db_file=temp_db_path, mode="workflow") + storage.create() + return storage + + +class SimpleWorkflow(Workflow): + """A simple workflow with a single agent for testing.""" + + description: str = "A simple workflow for testing storage" + + test_agent: Agent = Agent( + description="A test agent for the workflow", + ) + + def run(self, query: str) -> RunResponse: + """Run the workflow with a simple query.""" + response = self.test_agent.run(query) + return RunResponse(run_id=self.run_id, content=f"Workflow processed: {response.content}") + + +@pytest.fixture +def workflow_with_storage(workflow_storage): + """Create a workflow with the test storage.""" + return SimpleWorkflow(storage=workflow_storage, name="TestWorkflow") + + +def test_storage_creation(temp_db_path): + """Test that storage is created correctly.""" + storage = SqliteStorage(table_name="workflow_sessions", db_file=temp_db_path, mode="workflow") + storage.create() + assert os.path.exists(temp_db_path) + assert storage.table_exists() + + +def test_workflow_session_storage(workflow_with_storage, workflow_storage): + """Test that workflow sessions are properly stored.""" + # Run workflow and get response + workflow_with_storage.run(query="What is the capital of France?") + + assert workflow_with_storage.storage.mode == "workflow" + + # Get the session ID from the workflow + session_id = workflow_with_storage.session_id + + # Verify session was stored + stored_session = workflow_storage.read(session_id) + assert stored_session is not None + assert isinstance(stored_session, WorkflowSession) + assert stored_session.session_id == session_id + + # Verify workflow data was stored + assert stored_session.workflow_data is not None + assert stored_session.workflow_data.get("name") == "TestWorkflow" + + # Verify memory contains the run + assert stored_session.memory is not None + assert "runs" in stored_session.memory + assert len(stored_session.memory["runs"]) > 0 + + +def 
test_multiple_interactions(workflow_with_storage, workflow_storage): + """Test that multiple interactions are properly stored in the same session.""" + # First interaction + workflow_with_storage.run(query="What is the capital of France?") + + assert workflow_with_storage.storage.mode == "workflow" + session_id = workflow_with_storage.session_id + + # Second interaction + workflow_with_storage.run(query="What is its population?") + + # Verify both interactions are in the same session + stored_session = workflow_storage.read(session_id) + assert stored_session is not None + assert "runs" in stored_session.memory + assert len(stored_session.memory["runs"]) == 2 # Should have 2 runs + + +def test_session_retrieval_by_user(workflow_storage): + """Test retrieving sessions filtered by user ID.""" + # Create a session with a specific user ID + workflow = SimpleWorkflow(storage=workflow_storage, user_id="test_user", name="UserTestWorkflow") + workflow.run(query="What is the capital of France?") + + assert workflow.storage.mode == "workflow" + + # Get all sessions for the user + sessions = workflow_storage.get_all_sessions(user_id="test_user") + assert len(sessions) == 1 + assert sessions[0].user_id == "test_user" + + # Verify no sessions for different user + other_sessions = workflow_storage.get_all_sessions(user_id="other_user") + assert len(other_sessions) == 0 + + +def test_session_deletion(workflow_with_storage, workflow_storage): + """Test deleting a session.""" + # Create a session + workflow_with_storage.run(query="What is the capital of France?") + session_id = workflow_with_storage.session_id + + # Verify session exists + assert workflow_storage.read(session_id) is not None + + # Delete session + workflow_storage.delete_session(session_id) + + # Verify session was deleted + assert workflow_storage.read(session_id) is None + + +def test_get_all_session_ids(workflow_storage): + """Test retrieving all session IDs.""" + # Create multiple sessions with different user IDs and workflow IDs + workflow_1 = SimpleWorkflow(storage=workflow_storage, user_id="user1", workflow_id="workflow1", name="Workflow1") + workflow_2 = SimpleWorkflow(storage=workflow_storage, user_id="user1", workflow_id="workflow2", name="Workflow2") + workflow_3 = SimpleWorkflow(storage=workflow_storage, user_id="user2", workflow_id="workflow3", name="Workflow3") + + workflow_1.run(query="Question 1") + workflow_2.run(query="Question 2") + workflow_3.run(query="Question 3") + + # Get all session IDs + all_sessions = workflow_storage.get_all_session_ids() + assert len(all_sessions) == 3 + + # Filter by user ID + user1_sessions = workflow_storage.get_all_session_ids(user_id="user1") + assert len(user1_sessions) == 2 + + # Filter by workflow ID + workflow1_sessions = workflow_storage.get_all_session_ids(entity_id="workflow1") + assert len(workflow1_sessions) == 1 + + # Filter by both + filtered_sessions = workflow_storage.get_all_session_ids(user_id="user1", entity_id="workflow2") + assert len(filtered_sessions) == 1 + + +def test_drop_storage(workflow_storage): + """Test dropping all sessions from storage.""" + # Create a few sessions + for i in range(3): + workflow = SimpleWorkflow(storage=workflow_storage, name=f"Workflow{i}") + workflow.run(query=f"Question {i}") + + # Verify sessions exist + assert len(workflow_storage.get_all_session_ids()) == 3 + + # Drop all sessions + workflow_storage.drop() + + # Verify no sessions remain + assert len(workflow_storage.get_all_session_ids()) == 0 + + +def 
test_workflow_session_rename(workflow_with_storage, workflow_storage): + """Test renaming a workflow session.""" + # Create a session + workflow_with_storage.run(query="What is the capital of France?") + session_id = workflow_with_storage.session_id + + # Rename the session + workflow_with_storage.rename_session("My Renamed Session") + + # Verify session was renamed + stored_session = workflow_storage.read(session_id) + assert stored_session is not None + assert stored_session.session_data.get("session_name") == "My Renamed Session" + + +def test_workflow_rename(workflow_with_storage, workflow_storage): + """Test renaming a workflow.""" + # Create a session + workflow_with_storage.run(query="What is the capital of France?") + session_id = workflow_with_storage.session_id + + # Rename the workflow + workflow_with_storage.rename("My Renamed Workflow") + + # Verify workflow was renamed + stored_session = workflow_storage.read(session_id) + assert stored_session is not None + assert stored_session.workflow_data.get("name") == "My Renamed Workflow" diff --git a/libs/agno/tests/integration/storage/test_yaml_storage_agent.py b/libs/agno/tests/integration/storage/test_yaml_storage_agent.py new file mode 100644 index 0000000000..7b14b76590 --- /dev/null +++ b/libs/agno/tests/integration/storage/test_yaml_storage_agent.py @@ -0,0 +1,151 @@ +import shutil + +import pytest + +from agno.agent import Agent +from agno.storage.session.agent import AgentSession +from agno.storage.yaml import YamlStorage + + +@pytest.fixture +def temp_storage_path(tmp_path): + """Create a temporary directory for storage that's cleaned up after tests.""" + storage_dir = tmp_path / "test_storage" + storage_dir.mkdir() + yield storage_dir + shutil.rmtree(storage_dir) + + +@pytest.fixture +def agent_storage(temp_storage_path): + """Create a YamlStorage instance for agent sessions.""" + return YamlStorage(dir_path=temp_storage_path, mode="agent") + + +@pytest.fixture +def workflow_storage(temp_storage_path): + """Create a YamlStorage instance for workflow sessions.""" + return YamlStorage(dir_path=temp_storage_path / "workflows", mode="workflow") + + +@pytest.fixture +def agent_with_storage(agent_storage): + """Create an agent with the test storage.""" + return Agent(storage=agent_storage, add_history_to_messages=True) + + +def test_storage_creation(temp_storage_path): + """Test that storage directory is created.""" + YamlStorage(dir_path=temp_storage_path) + assert temp_storage_path.exists() + assert temp_storage_path.is_dir() + + +def test_agent_session_storage(agent_with_storage, agent_storage): + """Test that agent sessions are properly stored.""" + # Run agent and get response + agent_with_storage.run("What is the capital of France?") + + # Get the session ID from the agent + session_id = agent_with_storage.session_id + + # Verify session was stored + stored_session = agent_storage.read(session_id) + assert stored_session is not None + assert isinstance(stored_session, AgentSession) + assert stored_session.session_id == session_id + + # Verify session contains the interaction + assert len(stored_session.memory["messages"]) > 0 + + +def test_multiple_interactions(agent_with_storage, agent_storage): + """Test that multiple interactions are properly stored in the same session.""" + # First interaction + agent_with_storage.run("What is the capital of France?") + session_id = agent_with_storage.session_id + + # Second interaction + agent_with_storage.run("What is its population?") + + # Verify both interactions are in the same 
session + stored_session = agent_storage.read(session_id) + assert stored_session is not None + assert len(stored_session.memory["messages"]) >= 4 # Should have at least 4 messages (2 questions + 2 responses) + + +def test_session_retrieval_by_user(agent_with_storage, agent_storage): + """Test retrieving sessions filtered by user ID.""" + # Create a session with a specific user ID + agent_with_storage.user_id = "test_user" + agent_with_storage.run("What is the capital of France?") + + # Get all sessions for the user + sessions = agent_storage.get_all_sessions(user_id="test_user") + assert len(sessions) == 1 + assert sessions[0].user_id == "test_user" + + # Verify no sessions for different user + other_sessions = agent_storage.get_all_sessions(user_id="other_user") + assert len(other_sessions) == 0 + + +def test_session_deletion(agent_with_storage, agent_storage): + """Test deleting a session.""" + # Create a session + agent_with_storage.run("What is the capital of France?") + session_id = agent_with_storage.session_id + + # Verify session exists + assert agent_storage.read(session_id) is not None + + # Delete session + agent_storage.delete_session(session_id) + + # Verify session was deleted + assert agent_storage.read(session_id) is None + + +def test_get_all_session_ids(agent_storage): + """Test retrieving all session IDs.""" + # Create multiple sessions with different user IDs and agent IDs + agent_1 = Agent(storage=agent_storage, user_id="user1", agent_id="agent1", add_history_to_messages=True) + agent_2 = Agent(storage=agent_storage, user_id="user1", agent_id="agent2", add_history_to_messages=True) + agent_3 = Agent(storage=agent_storage, user_id="user2", agent_id="agent3", add_history_to_messages=True) + + agent_1.run("Question 1") + agent_2.run("Question 2") + agent_3.run("Question 3") + + # Get all session IDs + all_sessions = agent_storage.get_all_session_ids() + assert len(all_sessions) == 3 + + # Filter by user ID + user1_sessions = agent_storage.get_all_session_ids(user_id="user1") + assert len(user1_sessions) == 2 + + # Filter by agent ID + agent1_sessions = agent_storage.get_all_session_ids(entity_id="agent1") + assert len(agent1_sessions) == 1 + + # Filter by both + filtered_sessions = agent_storage.get_all_session_ids(user_id="user1", entity_id="agent2") + assert len(filtered_sessions) == 1 + + +def test_drop_storage(agent_with_storage, agent_storage): + """Test dropping all sessions from storage.""" + # Create a few sessions + for i in range(3): + agent = Agent(storage=agent_storage, add_history_to_messages=True) + agent.run(f"Question {i}") + + # Verify sessions exist + assert len(agent_storage.get_all_session_ids()) == 3 + + # Drop all sessions + agent_storage.drop() + + # Verify no sessions remain + assert len(agent_storage.get_all_session_ids()) == 0 diff --git a/libs/agno/tests/integration/storage/test_yaml_storage_workflow.py b/libs/agno/tests/integration/storage/test_yaml_storage_workflow.py new file mode 100644 index 0000000000..d1b146adfe --- /dev/null +++ b/libs/agno/tests/integration/storage/test_yaml_storage_workflow.py @@ -0,0 +1,205 @@ +import shutil + +import pytest + +from agno.agent import Agent +from agno.run.response import RunResponse +from agno.storage.session.workflow import WorkflowSession +from agno.storage.yaml import YamlStorage +from agno.workflow import Workflow + + +@pytest.fixture +def temp_storage_path(tmp_path): + """Create a temporary directory for storage that's cleaned up after tests.""" + storage_dir = tmp_path / "test_storage" + 
storage_dir.mkdir() + yield storage_dir + shutil.rmtree(storage_dir) + + +@pytest.fixture +def workflow_storage(temp_storage_path): + """Create a YamlStorage instance for workflow sessions.""" + return YamlStorage(dir_path=temp_storage_path, mode="workflow") + + +class SimpleWorkflow(Workflow): + """A simple workflow with a single agent for testing.""" + + description: str = "A simple workflow for testing storage" + + test_agent: Agent = Agent( + description="A test agent for the workflow", + ) + + def run(self, query: str) -> RunResponse: + """Run the workflow with a simple query.""" + response = self.test_agent.run(query) + return RunResponse(run_id=self.run_id, content=f"Workflow processed: {response.content}") + + +@pytest.fixture +def workflow_with_storage(workflow_storage): + """Create a workflow with the test storage.""" + return SimpleWorkflow(storage=workflow_storage, name="TestWorkflow") + + +def test_storage_creation(temp_storage_path): + """Test that storage directory is created.""" + YamlStorage(dir_path=temp_storage_path, mode="workflow") + assert temp_storage_path.exists() + assert temp_storage_path.is_dir() + + +def test_workflow_session_storage(workflow_with_storage, workflow_storage): + """Test that workflow sessions are properly stored.""" + # Run workflow and get response + workflow_with_storage.run(query="What is the capital of France?") + + assert workflow_with_storage.storage.mode == "workflow" + + # Get the session ID from the workflow + session_id = workflow_with_storage.session_id + + # Verify session was stored + stored_session = workflow_storage.read(session_id) + assert stored_session is not None + assert isinstance(stored_session, WorkflowSession) + assert stored_session.session_id == session_id + + # Verify workflow data was stored + assert stored_session.workflow_data is not None + assert stored_session.workflow_data.get("name") == "TestWorkflow" + + # Verify memory contains the run + assert stored_session.memory is not None + assert "runs" in stored_session.memory + assert len(stored_session.memory["runs"]) > 0 + + +def test_multiple_interactions(workflow_with_storage, workflow_storage): + """Test that multiple interactions are properly stored in the same session.""" + # First interaction + workflow_with_storage.run(query="What is the capital of France?") + + assert workflow_with_storage.storage.mode == "workflow" + session_id = workflow_with_storage.session_id + + # Second interaction + workflow_with_storage.run(query="What is its population?") + + # Verify both interactions are in the same session + stored_session = workflow_storage.read(session_id) + assert stored_session is not None + assert "runs" in stored_session.memory + assert len(stored_session.memory["runs"]) == 2 # Should have 2 runs + + +def test_session_retrieval_by_user(workflow_storage): + """Test retrieving sessions filtered by user ID.""" + # Create a session with a specific user ID + workflow = SimpleWorkflow(storage=workflow_storage, user_id="test_user", name="UserTestWorkflow") + workflow.run(query="What is the capital of France?") + + assert workflow.storage.mode == "workflow" + + # Get all sessions for the user + sessions = workflow_storage.get_all_sessions(user_id="test_user") + assert len(sessions) == 1 + assert sessions[0].user_id == "test_user" + + # Verify no sessions for different user + other_sessions = workflow_storage.get_all_sessions(user_id="other_user") + assert len(other_sessions) == 0 + + +def test_session_deletion(workflow_with_storage, workflow_storage): + """Test 
deleting a session.""" + # Create a session + workflow_with_storage.run(query="What is the capital of France?") + session_id = workflow_with_storage.session_id + + # Verify session exists + assert workflow_storage.read(session_id) is not None + + # Delete session + workflow_storage.delete_session(session_id) + + # Verify session was deleted + assert workflow_storage.read(session_id) is None + + +def test_get_all_session_ids(workflow_storage): + """Test retrieving all session IDs.""" + # Create multiple sessions with different user IDs and workflow IDs + workflow_1 = SimpleWorkflow(storage=workflow_storage, user_id="user1", workflow_id="workflow1", name="Workflow1") + workflow_2 = SimpleWorkflow(storage=workflow_storage, user_id="user1", workflow_id="workflow2", name="Workflow2") + workflow_3 = SimpleWorkflow(storage=workflow_storage, user_id="user2", workflow_id="workflow3", name="Workflow3") + + workflow_1.run(query="Question 1") + workflow_2.run(query="Question 2") + workflow_3.run(query="Question 3") + + # Get all session IDs + all_sessions = workflow_storage.get_all_session_ids() + assert len(all_sessions) == 3 + + # Filter by user ID + user1_sessions = workflow_storage.get_all_session_ids(user_id="user1") + assert len(user1_sessions) == 2 + + # Filter by workflow ID + workflow1_sessions = workflow_storage.get_all_session_ids(entity_id="workflow1") + assert len(workflow1_sessions) == 1 + + # Filter by both + filtered_sessions = workflow_storage.get_all_session_ids(user_id="user1", entity_id="workflow2") + assert len(filtered_sessions) == 1 + + +def test_drop_storage(workflow_storage): + """Test dropping all sessions from storage.""" + # Create a few sessions + for i in range(3): + workflow = SimpleWorkflow(storage=workflow_storage, name=f"Workflow{i}") + workflow.run(query=f"Question {i}") + + # Verify sessions exist + assert len(workflow_storage.get_all_session_ids()) == 3 + + # Drop all sessions + workflow_storage.drop() + + # Verify no sessions remain + assert len(workflow_storage.get_all_session_ids()) == 0 + + +def test_workflow_session_rename(workflow_with_storage, workflow_storage): + """Test renaming a workflow session.""" + # Create a session + workflow_with_storage.run(query="What is the capital of France?") + session_id = workflow_with_storage.session_id + + # Rename the session + workflow_with_storage.rename_session("My Renamed Session") + + # Verify session was renamed + stored_session = workflow_storage.read(session_id) + assert stored_session is not None + assert stored_session.session_data.get("session_name") == "My Renamed Session" + + +def test_workflow_rename(workflow_with_storage, workflow_storage): + """Test renaming a workflow.""" + # Create a session + workflow_with_storage.run(query="What is the capital of France?") + session_id = workflow_with_storage.session_id + + # Rename the workflow + workflow_with_storage.rename("My Renamed Workflow") + + # Verify workflow was renamed + stored_session = workflow_storage.read(session_id) + assert stored_session is not None + assert stored_session.workflow_data.get("name") == "My Renamed Workflow" diff --git a/libs/agno/tests/unit/__init__.py b/libs/agno/tests/unit/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/libs/agno/tests/unit/reader/__init__.py b/libs/agno/tests/unit/reader/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/libs/agno/tests/unit/storage/__init__.py b/libs/agno/tests/unit/storage/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git 
a/libs/agno/tests/unit/storage/test_dynamodb_storage.py b/libs/agno/tests/unit/storage/test_dynamodb_storage.py new file mode 100644 index 0000000000..c759a6a5fb --- /dev/null +++ b/libs/agno/tests/unit/storage/test_dynamodb_storage.py @@ -0,0 +1,382 @@ +from unittest.mock import MagicMock, patch + +import pytest +from boto3.dynamodb.conditions import Key + +from agno.storage.dynamodb import DynamoDbStorage +from agno.storage.session.agent import AgentSession +from agno.storage.session.workflow import WorkflowSession + + +@pytest.fixture +def mock_dynamodb_resource(): + """Create a mock boto3 DynamoDB resource.""" + with patch("agno.storage.dynamodb.boto3.resource") as mock_resource: + mock_table = MagicMock() + mock_resource.return_value.Table.return_value = mock_table + yield mock_resource, mock_table + + +@pytest.fixture +def agent_storage(mock_dynamodb_resource): + """Create a DynamoDbStorage instance for agent mode with mocked components.""" + mock_resource, mock_table = mock_dynamodb_resource + + # Mock table.wait_until_exists to avoid actual waiting + mock_table.wait_until_exists = MagicMock() + + # Create storage with create_table_if_not_exists=False to avoid table creation + storage = DynamoDbStorage( + table_name="agent_sessions", region_name="us-east-1", create_table_if_not_exists=False, mode="agent" + ) + + return storage, mock_table + + +@pytest.fixture +def workflow_storage(mock_dynamodb_resource): + """Create a DynamoDbStorage instance for workflow mode with mocked components.""" + mock_resource, mock_table = mock_dynamodb_resource + + # Mock table.wait_until_exists to avoid actual waiting + mock_table.wait_until_exists = MagicMock() + + # Create storage with create_table_if_not_exists=False to avoid table creation + storage = DynamoDbStorage( + table_name="workflow_sessions", region_name="us-east-1", create_table_if_not_exists=False, mode="workflow" + ) + + return storage, mock_table + + +def test_initialization(): + """Test DynamoDbStorage initialization with different parameters.""" + # Test with region_name + with patch("agno.storage.dynamodb.boto3.resource") as mock_resource: + mock_table = MagicMock() + mock_resource.return_value.Table.return_value = mock_table + mock_table.wait_until_exists = MagicMock() + + storage = DynamoDbStorage(table_name="test_table", region_name="us-west-2", create_table_if_not_exists=False) + + mock_resource.assert_called_once_with( + "dynamodb", region_name="us-west-2", aws_access_key_id=None, aws_secret_access_key=None, endpoint_url=None + ) + assert storage.table_name == "test_table" + assert storage.mode == "agent" # Default value + + # Test with credentials + with patch("agno.storage.dynamodb.boto3.resource") as mock_resource: + mock_table = MagicMock() + mock_resource.return_value.Table.return_value = mock_table + mock_table.wait_until_exists = MagicMock() + + storage = DynamoDbStorage( + table_name="test_table", + region_name="us-west-2", + aws_access_key_id="test-key", + aws_secret_access_key="test-secret", + create_table_if_not_exists=False, + ) + + mock_resource.assert_called_once_with( + "dynamodb", + region_name="us-west-2", + aws_access_key_id="test-key", + aws_secret_access_key="test-secret", + endpoint_url=None, + ) + + # Test with endpoint_url (for local testing) + with patch("agno.storage.dynamodb.boto3.resource") as mock_resource: + mock_table = MagicMock() + mock_resource.return_value.Table.return_value = mock_table + mock_table.wait_until_exists = MagicMock() + + storage = DynamoDbStorage( + table_name="test_table", 
endpoint_url="http://localhost:8000", create_table_if_not_exists=False + ) + + mock_resource.assert_called_once_with( + "dynamodb", + region_name=None, + aws_access_key_id=None, + aws_secret_access_key=None, + endpoint_url="http://localhost:8000", + ) + + +def test_agent_storage_crud(agent_storage): + """Test CRUD operations for agent storage.""" + storage, mock_table = agent_storage + + # Create a test session + session = AgentSession( + session_id="test-session", + agent_id="test-agent", + user_id="test-user", + memory={"key": "value"}, + agent_data={"name": "Test Agent"}, + session_data={"state": "active"}, + extra_data={"custom": "data"}, + ) + + # Test upsert + mock_table.put_item.return_value = {} # DynamoDB put_item returns empty dict on success + mock_table.get_item.return_value = {"Item": session.to_dict()} # Mock the read after upsert + + result = storage.upsert(session) + assert result is not None + assert result.session_id == session.session_id + assert result.agent_id == session.agent_id + mock_table.put_item.assert_called_once() + + # Test read + mock_table.get_item.reset_mock() + session_dict = session.to_dict() + mock_table.get_item.return_value = {"Item": session_dict} + read_result = storage.read("test-session") + assert read_result is not None + assert read_result.session_id == session.session_id + assert read_result.agent_id == session.agent_id + assert read_result.user_id == session.user_id + mock_table.get_item.assert_called_once_with(Key={"session_id": "test-session"}) + + # Test read with non-existent session + mock_table.get_item.reset_mock() + mock_table.get_item.return_value = {} # DynamoDB returns empty dict when item not found + read_result = storage.read("non-existent-session") + assert read_result is None + mock_table.get_item.assert_called_once_with(Key={"session_id": "non-existent-session"}) + + # Test delete + mock_table.delete_item.return_value = {} # DynamoDB delete_item returns empty dict on success + storage.delete_session("test-session") + mock_table.delete_item.assert_called_once_with(Key={"session_id": "test-session"}) + + +def test_workflow_storage_crud(workflow_storage): + """Test CRUD operations for workflow storage.""" + storage, mock_table = workflow_storage + + # Create a test session + session = WorkflowSession( + session_id="test-session", + workflow_id="test-workflow", + user_id="test-user", + memory={"key": "value"}, + workflow_data={"name": "Test Workflow"}, + session_data={"state": "active"}, + extra_data={"custom": "data"}, + ) + + # Mock the read method + original_read = storage.read + storage.read = MagicMock(return_value=session) + + # Test upsert + result = storage.upsert(session) + assert result == session + mock_table.put_item.assert_called_once() + + # Test read + mock_table.get_item.return_value = {"Item": session.to_dict()} + storage.read = original_read + read_result = storage.read("test-session") + assert read_result is not None + assert read_result.session_id == session.session_id + + # Test delete + storage.delete_session = MagicMock() + storage.delete_session("test-session") + storage.delete_session.assert_called_once_with("test-session") + + +def test_get_all_sessions(agent_storage): + """Test retrieving all sessions.""" + storage, mock_table = agent_storage + + # Create mock sessions + sessions = [] + for i in range(4): + session_data = { + "session_id": f"session-{i}", + "agent_id": f"agent-{i % 2 + 1}", + "user_id": f"user-{i % 2 + 1}", + "memory": {}, + "agent_data": {}, + "session_data": {}, + "extra_data": {}, + 
"created_at": 1000000, + "updated_at": None, + } + sessions.append(session_data) + + # Mock scan response for unfiltered query + mock_table.scan.return_value = {"Items": sessions} + + # Test get_all_sessions without filters + result = storage.get_all_sessions() + assert len(result) == 4 + assert all(isinstance(s, AgentSession) for s in result) + mock_table.scan.assert_called_once_with( + ProjectionExpression="session_id, agent_id, user_id, memory, agent_data, session_data, extra_data, created_at, updated_at" + ) + + # Test filtering by user_id + mock_table.scan.reset_mock() + mock_table.query.reset_mock() + user1_sessions = [s for s in sessions if s["user_id"] == "user-1"] + mock_table.query.return_value = {"Items": user1_sessions} + + result = storage.get_all_sessions(user_id="user-1") + assert len(result) == 2 + assert all(s.user_id == "user-1" for s in result) + mock_table.query.assert_called_once_with( + IndexName="user_id-index", + KeyConditionExpression=Key("user_id").eq("user-1"), + ProjectionExpression="session_id, agent_id, user_id, memory, agent_data, session_data, extra_data, created_at, updated_at", + ) + + # Test filtering by agent_id + mock_table.query.reset_mock() + agent1_sessions = [s for s in sessions if s["agent_id"] == "agent-1"] + mock_table.query.return_value = {"Items": agent1_sessions} + + result = storage.get_all_sessions(entity_id="agent-1") + assert len(result) == 2 + assert all(s.agent_id == "agent-1" for s in result) + mock_table.query.assert_called_once_with( + IndexName="agent_id-index", + KeyConditionExpression=Key("agent_id").eq("agent-1"), + ProjectionExpression="session_id, agent_id, user_id, memory, agent_data, session_data, extra_data, created_at, updated_at", + ) + + +def test_get_all_session_ids(agent_storage): + """Test retrieving all session IDs.""" + storage, mock_table = agent_storage + + # Mock the scan method to return session IDs + mock_response = {"Items": [{"session_id": "session-1"}, {"session_id": "session-2"}, {"session_id": "session-3"}]} + mock_table.scan.return_value = mock_response + + # Test get_all_session_ids without filters + result = storage.get_all_session_ids() + assert result == ["session-1", "session-2", "session-3"] + mock_table.scan.assert_called_once_with(ProjectionExpression="session_id") + + # Test with user_id filter + mock_table.scan.reset_mock() + mock_table.query.return_value = mock_response + + result = storage.get_all_session_ids(user_id="test-user") + assert result == ["session-1", "session-2", "session-3"] + mock_table.query.assert_called_once_with( + IndexName="user_id-index", + KeyConditionExpression=Key("user_id").eq("test-user"), + ProjectionExpression="session_id", + ) + + # Test with entity_id filter (agent_id in agent mode) + mock_table.query.reset_mock() + mock_table.query.return_value = mock_response + + result = storage.get_all_session_ids(entity_id="test-agent") + assert result == ["session-1", "session-2", "session-3"] + mock_table.query.assert_called_once_with( + IndexName="agent_id-index", + KeyConditionExpression=Key("agent_id").eq("test-agent"), + ProjectionExpression="session_id", + ) + + +def test_drop_table(agent_storage): + """Test dropping a table.""" + storage, mock_table = agent_storage + + # Mock the delete and wait_until_not_exists methods + mock_table.delete = MagicMock() + mock_table.wait_until_not_exists = MagicMock() + + # Call drop + storage.drop() + + # Verify delete was called + mock_table.delete.assert_called_once() + mock_table.wait_until_not_exists.assert_called_once() + + +def 
test_mode_switching(): + """Test switching between agent and workflow modes.""" + with patch("agno.storage.dynamodb.boto3.resource") as mock_resource: + mock_table = MagicMock() + mock_resource.return_value.Table.return_value = mock_table + mock_table.wait_until_exists = MagicMock() + + # Create storage in agent mode + storage = DynamoDbStorage(table_name="test_table", create_table_if_not_exists=False) + assert storage.mode == "agent" + + # Switch to workflow mode + with patch.object(storage, "create") as mock_create: + storage.mode = "workflow" + assert storage.mode == "workflow" + # Since create_table_if_not_exists is False, create should not be called + mock_create.assert_not_called() + + # Test with create_table_if_not_exists=True + storage.create_table_if_not_exists = True + with patch.object(storage, "create") as mock_create: + storage.mode = "agent" + assert storage.mode == "agent" + mock_create.assert_called_once() + + +def test_serialization_deserialization(agent_storage): + """Test serialization and deserialization of items.""" + storage, _ = agent_storage + + # Test serialization + test_item = { + "int_value": 42, + "float_value": 3.14, + "str_value": "test", + "bool_value": True, + "list_value": [1, 2, 3], + "dict_value": {"key": "value"}, + "nested_dict": {"nested": {"float": 1.23, "list": [4, 5, 6]}}, + "none_value": None, + } + + serialized = storage._serialize_item(test_item) + + # None values should be removed + assert "none_value" not in serialized + + # Test deserialization + from decimal import Decimal + + decimal_item = { + "int_value": Decimal("42"), + "float_value": Decimal("3.14"), + "str_value": "test", + "bool_value": True, + "list_value": [Decimal("1"), Decimal("2"), Decimal("3")], + "dict_value": {"key": "value"}, + "nested_dict": {"nested": {"float": Decimal("1.23"), "list": [Decimal("4"), Decimal("5"), Decimal("6")]}}, + } + + deserialized = storage._deserialize_item(decimal_item) + + # Decimals should be converted to int or float + assert isinstance(deserialized["int_value"], int) + assert deserialized["int_value"] == 42 + + assert isinstance(deserialized["float_value"], float) + assert deserialized["float_value"] == 3.14 + + # Nested values should also be converted + assert isinstance(deserialized["list_value"][0], int) + assert isinstance(deserialized["nested_dict"]["nested"]["float"], float) + assert isinstance(deserialized["nested_dict"]["nested"]["list"][0], int) diff --git a/libs/agno/tests/unit/storage/test_json_storage.py b/libs/agno/tests/unit/storage/test_json_storage.py new file mode 100644 index 0000000000..4c6e3a6ef0 --- /dev/null +++ b/libs/agno/tests/unit/storage/test_json_storage.py @@ -0,0 +1,176 @@ +import tempfile +from pathlib import Path +from typing import Generator + +import pytest + +from agno.storage.json import JsonStorage +from agno.storage.session.agent import AgentSession +from agno.storage.session.workflow import WorkflowSession + + +@pytest.fixture +def temp_dir() -> Generator[Path, None, None]: + with tempfile.TemporaryDirectory() as temp_dir: + yield Path(temp_dir) + + +@pytest.fixture +def agent_storage(temp_dir: Path) -> JsonStorage: + return JsonStorage(dir_path=temp_dir) + + +@pytest.fixture +def workflow_storage(temp_dir: Path) -> JsonStorage: + return JsonStorage(dir_path=temp_dir, mode="workflow") + + +def test_agent_storage_crud(agent_storage: JsonStorage, temp_dir: Path): + # Test create + agent_storage.create() + assert temp_dir.exists() + + # Test upsert + session = AgentSession( + session_id="test-session", + 
agent_id="test-agent", + user_id="test-user", + memory={"key": "value"}, + agent_data={"name": "Test Agent"}, + session_data={"state": "active"}, + extra_data={"custom": "data"}, + ) + + saved_session = agent_storage.upsert(session) + assert saved_session is not None + assert saved_session.session_id == session.session_id + assert (temp_dir / "test-session.json").exists() + + # Test read + read_session = agent_storage.read("test-session") + assert read_session is not None + assert read_session.session_id == session.session_id + assert read_session.agent_id == session.agent_id + assert read_session.memory == session.memory + + # Test get all sessions + all_sessions = agent_storage.get_all_sessions() + assert len(all_sessions) == 1 + assert all_sessions[0].session_id == session.session_id + + # Test delete + agent_storage.delete_session("test-session") + assert agent_storage.read("test-session") is None + assert not (temp_dir / "test-session.json").exists() + + +def test_workflow_storage_crud(workflow_storage: JsonStorage, temp_dir: Path): + # Test create + workflow_storage.create() + assert temp_dir.exists() + + # Test upsert + session = WorkflowSession( + session_id="test-session", + workflow_id="test-workflow", + user_id="test-user", + memory={"key": "value"}, + workflow_data={"name": "Test Workflow"}, + session_data={"state": "active"}, + extra_data={"custom": "data"}, + ) + + saved_session = workflow_storage.upsert(session) + assert saved_session is not None + assert saved_session.session_id == session.session_id + assert (temp_dir / "test-session.json").exists() + + # Test read + read_session = workflow_storage.read("test-session") + assert read_session is not None + assert read_session.session_id == session.session_id + assert read_session.workflow_id == session.workflow_id + assert read_session.memory == session.memory + + # Test get all sessions + all_sessions = workflow_storage.get_all_sessions() + assert len(all_sessions) == 1 + assert all_sessions[0].session_id == session.session_id + + # Test delete + workflow_storage.delete_session("test-session") + assert workflow_storage.read("test-session") is None + assert not (temp_dir / "test-session.json").exists() + + +def test_storage_filtering(agent_storage: JsonStorage): + # Create test sessions + sessions = [ + AgentSession( + session_id=f"session-{i}", + agent_id="agent-1" if i < 2 else "agent-2", + user_id="user-1" if i % 2 == 0 else "user-2", + ) + for i in range(4) + ] + + for session in sessions: + agent_storage.upsert(session) + + # Test filtering by user_id + user1_sessions = agent_storage.get_all_sessions(user_id="user-1") + assert len(user1_sessions) == 2 + assert all(s.user_id == "user-1" for s in user1_sessions) + + # Test filtering by agent_id + agent1_sessions = agent_storage.get_all_sessions(entity_id="agent-1") + assert len(agent1_sessions) == 2 + assert all(s.agent_id == "agent-1" for s in agent1_sessions) + + # Test combined filtering + filtered_sessions = agent_storage.get_all_sessions(user_id="user-1", entity_id="agent-1") + assert len(filtered_sessions) == 1 + assert filtered_sessions[0].user_id == "user-1" + assert filtered_sessions[0].agent_id == "agent-1" + + +def test_workflow_storage_filtering(workflow_storage: JsonStorage): + # Create test sessions + sessions = [ + WorkflowSession( + session_id=f"session-{i}", + workflow_id="workflow-1" if i < 2 else "workflow-2", + user_id="user-1" if i % 2 == 0 else "user-2", + memory={"key": f"value-{i}"}, + workflow_data={"name": f"Test Workflow {i}"}, + 
session_data={"state": "active"}, + extra_data={"custom": f"data-{i}"}, + ) + for i in range(4) + ] + + for session in sessions: + workflow_storage.upsert(session) + + # Test filtering by user_id + user1_sessions = workflow_storage.get_all_sessions(user_id="user-1") + assert len(user1_sessions) == 2 + assert all(s.user_id == "user-1" for s in user1_sessions) + + # Test filtering by workflow_id + workflow1_sessions = workflow_storage.get_all_sessions(entity_id="workflow-1") + assert len(workflow1_sessions) == 2 + assert all(s.workflow_id == "workflow-1" for s in workflow1_sessions) + + # Test combined filtering + filtered_sessions = workflow_storage.get_all_sessions(user_id="user-1", entity_id="workflow-1") + assert len(filtered_sessions) == 1 + assert filtered_sessions[0].user_id == "user-1" + assert filtered_sessions[0].workflow_id == "workflow-1" + + # Test filtering with non-existent IDs + empty_sessions = workflow_storage.get_all_sessions(user_id="non-existent") + assert len(empty_sessions) == 0 + + empty_sessions = workflow_storage.get_all_sessions(entity_id="non-existent") + assert len(empty_sessions) == 0 diff --git a/libs/agno/tests/unit/storage/test_mongodb_storage.py b/libs/agno/tests/unit/storage/test_mongodb_storage.py new file mode 100644 index 0000000000..990dd020bf --- /dev/null +++ b/libs/agno/tests/unit/storage/test_mongodb_storage.py @@ -0,0 +1,318 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from agno.storage.mongodb import MongoDbStorage +from agno.storage.session.agent import AgentSession +from agno.storage.session.workflow import WorkflowSession + + +@pytest.fixture +def mock_mongo_client(): + """Create a mock MongoDB client.""" + with patch("agno.storage.mongodb.MongoClient") as mock_client: + mock_collection = MagicMock() + mock_client.return_value.__getitem__.return_value.__getitem__.return_value = mock_collection + yield mock_client, mock_collection + + +@pytest.fixture +def agent_storage(mock_mongo_client): + """Create a MongoDbStorage instance for agent mode with mocked components.""" + mock_client, mock_collection = mock_mongo_client + + storage = MongoDbStorage(collection_name="agent_sessions", db_name="test_db", mode="agent") + + return storage, mock_collection + + +@pytest.fixture +def workflow_storage(mock_mongo_client): + """Create a MongoDbStorage instance for workflow mode with mocked components.""" + mock_client, mock_collection = mock_mongo_client + + storage = MongoDbStorage(collection_name="workflow_sessions", db_name="test_db", mode="workflow") + + return storage, mock_collection + + +def test_initialization(): + """Test MongoDbStorage initialization with different parameters.""" + # Test with db_url + with patch("agno.storage.mongodb.MongoClient") as mock_client: + mock_collection = MagicMock() + mock_client.return_value.__getitem__.return_value.__getitem__.return_value = mock_collection + + storage = MongoDbStorage( + collection_name="test_collection", db_url="mongodb://localhost:27017", db_name="test_db" + ) + + mock_client.assert_called_once_with("mongodb://localhost:27017") + assert storage.collection_name == "test_collection" + assert storage.db_name == "test_db" + assert storage.mode == "agent" # Default value + + # Test with existing client + with patch("agno.storage.mongodb.MongoClient") as mock_client: + mock_existing_client = MagicMock() + mock_collection = MagicMock() + mock_existing_client.__getitem__.return_value.__getitem__.return_value = mock_collection + + storage = 
MongoDbStorage(collection_name="test_collection", db_name="test_db", client=mock_existing_client) + + mock_client.assert_not_called() # Should not create a new client + assert storage.collection_name == "test_collection" + assert storage.db_name == "test_db" + + # Test with no parameters + with patch("agno.storage.mongodb.MongoClient") as mock_client: + mock_collection = MagicMock() + mock_client.return_value.__getitem__.return_value.__getitem__.return_value = mock_collection + + storage = MongoDbStorage(collection_name="test_collection") + + mock_client.assert_called_once() # Should create a default client + assert storage.collection_name == "test_collection" + assert storage.db_name == "agno" # Default value + + +def test_create_indexes(agent_storage): + """Test creating indexes.""" + storage, mock_collection = agent_storage + + # Mock create_index + mock_collection.create_index = MagicMock() + + # Call create + storage.create() + + # Verify create_index was called for each index + assert mock_collection.create_index.call_count >= 4 # At least 4 indexes + + # Verify agent_id index is created in agent mode + mock_collection.create_index.assert_any_call("agent_id") + + # Test in workflow mode + storage.mode = "workflow" + mock_collection.create_index.reset_mock() + + storage.create() + + # Verify workflow_id index is created in workflow mode + mock_collection.create_index.assert_any_call("workflow_id") + + +def test_agent_storage_crud(agent_storage): + """Test CRUD operations for agent storage.""" + storage, mock_collection = agent_storage + + # Create a test session + session = AgentSession( + session_id="test-session", + agent_id="test-agent", + user_id="test-user", + memory={"key": "value"}, + agent_data={"name": "Test Agent"}, + session_data={"state": "active"}, + extra_data={"custom": "data"}, + ) + + # Mock find_one for read + mock_collection.find_one.return_value = {**session.to_dict(), "_id": "mock_id"} + + # Test read + read_result = storage.read("test-session") + assert read_result is not None + assert read_result.session_id == session.session_id + + # Mock update_one for upsert + mock_collection.update_one.return_value = MagicMock(acknowledged=True) + + # Mock the read method for upsert + original_read = storage.read + storage.read = MagicMock(return_value=session) + + # Test upsert + result = storage.upsert(session) + assert result == session + mock_collection.update_one.assert_called_once() + + # Restore original read method + storage.read = original_read + + # Test delete + storage.delete_session = MagicMock() + storage.delete_session("test-session") + storage.delete_session.assert_called_once_with("test-session") + + +def test_workflow_storage_crud(workflow_storage): + """Test CRUD operations for workflow storage.""" + storage, mock_collection = workflow_storage + + # Create a test session + session = WorkflowSession( + session_id="test-session", + workflow_id="test-workflow", + user_id="test-user", + memory={"key": "value"}, + workflow_data={"name": "Test Workflow"}, + session_data={"state": "active"}, + extra_data={"custom": "data"}, + ) + + # Mock find_one for read + mock_collection.find_one.return_value = {**session.to_dict(), "_id": "mock_id"} + + # Test read + read_result = storage.read("test-session") + assert read_result is not None + assert read_result.session_id == session.session_id + + # Mock update_one for upsert + mock_collection.update_one.return_value = MagicMock(acknowledged=True) + + # Mock the read method for upsert + original_read = storage.read + 
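Several of these CRUD tests hand-roll a swap-and-restore of `storage.read`. For reference, `unittest.mock.patch.object` achieves the same thing with automatic restoration, which is slightly safer if an assertion fails before the manual restore line runs; a sketch:

```python
from unittest.mock import MagicMock, patch

class Storage:
    def read(self, session_id: str):
        raise RuntimeError("would hit the database")

storage = Storage()
with patch.object(storage, "read", MagicMock(return_value="session")) as mock_read:
    assert storage.read("test-session") == "session"
    mock_read.assert_called_once_with("test-session")

# Outside the context manager the original method is back automatically.
assert storage.read.__name__ == "read"
```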
storage.read = MagicMock(return_value=session) + + # Test upsert + result = storage.upsert(session) + assert result == session + mock_collection.update_one.assert_called_once() + + # Restore original read method + storage.read = original_read + + # Test delete + storage.delete_session = MagicMock() + storage.delete_session("test-session") + storage.delete_session.assert_called_once_with("test-session") + + +def test_get_all_sessions(agent_storage): + """Test retrieving all sessions.""" + storage, mock_collection = agent_storage + + # Create mock sessions + sessions = [ + AgentSession( + session_id=f"session-{i}", + agent_id=f"agent-{i % 2 + 1}", + user_id=f"user-{i % 2 + 1}", + ) + for i in range(4) + ] + + # Mock the get_all_sessions method directly + original_get_all_sessions = storage.get_all_sessions + storage.get_all_sessions = MagicMock(return_value=sessions) + + # Test get_all_sessions + result = storage.get_all_sessions() + assert len(result) == 4 + + # Test filtering by user_id + user1_sessions = [s for s in sessions if s.user_id == "user-1"] + storage.get_all_sessions = MagicMock(return_value=user1_sessions) + + result = storage.get_all_sessions(user_id="user-1") + assert len(result) == 2 + assert all(s.user_id == "user-1" for s in result) + + # Test filtering by agent_id + agent1_sessions = [s for s in sessions if s.agent_id == "agent-1"] + storage.get_all_sessions = MagicMock(return_value=agent1_sessions) + + result = storage.get_all_sessions(entity_id="agent-1") + assert len(result) == 2 + assert all(s.agent_id == "agent-1" for s in result) + + # Restore original method + storage.get_all_sessions = original_get_all_sessions + + +def test_get_all_session_ids(agent_storage): + """Test retrieving all session IDs.""" + storage, mock_collection = agent_storage + + # Mock the find method to return session IDs + mock_cursor = MagicMock() + mock_cursor.sort.return_value = [ + {"session_id": "session-1"}, + {"session_id": "session-2"}, + {"session_id": "session-3"}, + ] + mock_collection.find.return_value = mock_cursor + + # Test get_all_session_ids without filters + result = storage.get_all_session_ids() + assert result == ["session-1", "session-2", "session-3"] + mock_collection.find.assert_called_once_with({}, {"session_id": 1}) + + # Test with user_id filter + mock_collection.find.reset_mock() + mock_cursor.sort.return_value = [{"session_id": "session-1"}, {"session_id": "session-2"}] + mock_collection.find.return_value = mock_cursor + + result = storage.get_all_session_ids(user_id="test-user") + assert result == ["session-1", "session-2"] + mock_collection.find.assert_called_once_with({"user_id": "test-user"}, {"session_id": 1}) + + # Test with entity_id filter (agent_id in agent mode) + mock_collection.find.reset_mock() + mock_cursor.sort.return_value = [{"session_id": "session-3"}] + mock_collection.find.return_value = mock_cursor + + result = storage.get_all_session_ids(entity_id="test-agent") + assert result == ["session-3"] + mock_collection.find.assert_called_once_with({"agent_id": "test-agent"}, {"session_id": 1}) + + +def test_drop_collection(agent_storage): + """Test dropping a collection.""" + storage, mock_collection = agent_storage + + # Mock the drop method + mock_collection.drop = MagicMock() + + # Call drop + storage.drop() + + # Verify drop was called + mock_collection.drop.assert_called_once() + + +def test_mode_switching(): + """Test switching between agent and workflow modes.""" + with patch("agno.storage.mongodb.MongoClient") as mock_client: + mock_collection = 
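The `find({...}, {"session_id": 1})` calls asserted in `test_get_all_session_ids` use a MongoDB projection so only the `session_id` field is returned. A sketch of the filter/projection pair those assertions imply, reconstructed from the test expectations rather than from the actual implementation:

```python
from typing import Optional, Tuple

def build_session_id_query(mode: str, user_id: Optional[str] = None,
                           entity_id: Optional[str] = None) -> Tuple[dict, dict]:
    # Filter document plus a projection that returns only session_id.
    query: dict = {}
    if user_id is not None:
        query["user_id"] = user_id
    if entity_id is not None:
        # agent_id in agent mode, workflow_id in workflow mode
        query["agent_id" if mode == "agent" else "workflow_id"] = entity_id
    return query, {"session_id": 1}

assert build_session_id_query("agent") == ({}, {"session_id": 1})
assert build_session_id_query("agent", entity_id="test-agent") == (
    {"agent_id": "test-agent"}, {"session_id": 1}
)
```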
MagicMock() + mock_client.return_value.__getitem__.return_value.__getitem__.return_value = mock_collection + + # Create storage in agent mode + storage = MongoDbStorage(collection_name="test_collection") + assert storage.mode == "agent" + + # Switch to workflow mode + storage.mode = "workflow" + assert storage.mode == "workflow" + + +def test_deepcopy(agent_storage): + """Test deep copying the storage instance.""" + from copy import deepcopy + + storage, _ = agent_storage + + # Deep copy the storage + copied_storage = deepcopy(storage) + + # Verify the copy has the same attributes + assert copied_storage.collection_name == storage.collection_name + assert copied_storage.db_name == storage.db_name + assert copied_storage.mode == storage.mode + + # Verify the copy shares the same client, db, and collection references + assert copied_storage._client is storage._client + assert copied_storage.db is storage.db + assert copied_storage.collection is storage.collection diff --git a/libs/agno/tests/unit/storage/test_postgres_storage.py b/libs/agno/tests/unit/storage/test_postgres_storage.py new file mode 100644 index 0000000000..cacabd051d --- /dev/null +++ b/libs/agno/tests/unit/storage/test_postgres_storage.py @@ -0,0 +1,274 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from agno.storage.postgres import PostgresStorage +from agno.storage.session.agent import AgentSession +from agno.storage.session.workflow import WorkflowSession + + +@pytest.fixture +def mock_engine(): + """Create a mock SQLAlchemy engine.""" + engine = MagicMock() + return engine + + +@pytest.fixture +def mock_session(): + """Create a mock SQLAlchemy session.""" + session = MagicMock() + session_instance = MagicMock() + session.return_value.__enter__.return_value = session_instance + return session, session_instance + + +@pytest.fixture +def agent_storage(mock_engine, mock_session): + """Create a PostgresStorage instance for agent mode with mocked components.""" + with patch("agno.storage.postgres.scoped_session", return_value=mock_session[0]): + with patch("agno.storage.postgres.inspect", return_value=MagicMock()): + storage = PostgresStorage(table_name="agent_sessions", schema="ai", db_engine=mock_engine, mode="agent") + # Mock table_exists to return True + storage.table_exists = MagicMock(return_value=True) + return storage, mock_session[1] + + +@pytest.fixture +def workflow_storage(mock_engine, mock_session): + """Create a PostgresStorage instance for workflow mode with mocked components.""" + with patch("agno.storage.postgres.scoped_session", return_value=mock_session[0]): + with patch("agno.storage.postgres.inspect", return_value=MagicMock()): + storage = PostgresStorage( + table_name="workflow_sessions", schema="ai", db_engine=mock_engine, mode="workflow" + ) + # Mock table_exists to return True + storage.table_exists = MagicMock(return_value=True) + return storage, mock_session[1] + + +def test_agent_storage_initialization(): + """Test PostgresStorage initialization with different parameters.""" + # Test with db_url + with patch("agno.storage.postgres.create_engine") as mock_create_engine: + with patch("agno.storage.postgres.scoped_session"): + with patch("agno.storage.postgres.inspect"): + mock_engine = MagicMock() + mock_create_engine.return_value = mock_engine + + storage = PostgresStorage(table_name="test_table", db_url="postgresql://user:pass@localhost/db") + + mock_create_engine.assert_called_once_with("postgresql://user:pass@localhost/db") + assert storage.table_name == "test_table" + assert 
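The deepcopy assertions encode a deliberate design choice: copying a storage object should not open new database connections. A minimal sketch of a `__deepcopy__` that shares the client while copying everything else; this is the behavior the test infers, not necessarily the exact implementation:

```python
from copy import deepcopy

class Storage:
    def __init__(self, client, collection_name: str):
        self._client = client  # expensive shared resource
        self.collection_name = collection_name

    def __deepcopy__(self, memo):
        cls = self.__class__
        copied = cls.__new__(cls)
        memo[id(self)] = copied
        for key, value in self.__dict__.items():
            if key == "_client":
                copied.__dict__[key] = value  # share, don't copy
            else:
                copied.__dict__[key] = deepcopy(value, memo)
        return copied

original = Storage(client=object(), collection_name="agent_sessions")
clone = deepcopy(original)
assert clone._client is original._client
assert clone.collection_name == original.collection_name
```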
storage.schema == "ai" # Default value + assert storage.mode == "agent" # Default value + + # Test with missing db_url and db_engine + with pytest.raises(ValueError, match="Must provide either db_url or db_engine"): + PostgresStorage(table_name="test_table") + + +def test_agent_storage_crud(agent_storage): + """Test CRUD operations for agent storage.""" + storage, mock_session = agent_storage + + # Create a test session + session = AgentSession( + session_id="test-session", + agent_id="test-agent", + user_id="test-user", + memory={"key": "value"}, + agent_data={"name": "Test Agent"}, + session_data={"state": "active"}, + extra_data={"custom": "data"}, + ) + + # Instead of mocking side_effect, directly mock the return value for upsert + # This simulates the behavior we want without relying on the internal implementation + original_read = storage.read + storage.read = MagicMock(return_value=None) # For initial read check + + # Mock upsert to return the session directly + original_upsert = storage.upsert + storage.upsert = MagicMock(return_value=session) + + # Test upsert + result = storage.upsert(session) + assert result == session + + # Restore original methods for other tests + storage.read = original_read + storage.upsert = original_upsert + + # Now test read with a direct mock + storage.read = MagicMock(return_value=session) + read_result = storage.read("test-session") + assert read_result == session + + # Test delete + storage.delete_session("test-session") + mock_session.execute.assert_called() + + +def test_workflow_storage_crud(workflow_storage): + """Test CRUD operations for workflow storage.""" + storage, mock_session = workflow_storage + + # Create a test session + session = WorkflowSession( + session_id="test-session", + workflow_id="test-workflow", + user_id="test-user", + memory={"key": "value"}, + workflow_data={"name": "Test Workflow"}, + session_data={"state": "active"}, + extra_data={"custom": "data"}, + ) + + # Instead of mocking side_effect, directly mock the return value for upsert + original_read = storage.read + storage.read = MagicMock(return_value=None) # For initial read check + + # Mock upsert to return the session directly + original_upsert = storage.upsert + storage.upsert = MagicMock(return_value=session) + + # Test upsert + result = storage.upsert(session) + assert result == session + + # Restore original methods for other tests + storage.read = original_read + storage.upsert = original_upsert + + # Now test read with a direct mock + storage.read = MagicMock(return_value=session) + read_result = storage.read("test-session") + assert read_result == session + + # Test delete + storage.delete_session("test-session") + mock_session.execute.assert_called() + + +def test_get_all_sessions(agent_storage): + """Test retrieving all sessions.""" + storage, mock_session = agent_storage + + # Create mock sessions + sessions = [ + AgentSession( + session_id=f"session-{i}", + agent_id=f"agent-{i % 2 + 1}", + user_id=f"user-{i % 2 + 1}", + ) + for i in range(4) + ] + + # Mock the fetchall result + mock_result = MagicMock() + mock_result.fetchall.return_value = [MagicMock(_mapping=session.to_dict()) for session in sessions] + mock_session.execute.return_value = mock_result + + # Test get_all_sessions + result = storage.get_all_sessions() + assert len(result) == 4 + + # Test filtering by user_id + mock_session.execute.reset_mock() + mock_result.fetchall.return_value = [ + MagicMock(_mapping=session.to_dict()) for session in sessions if session.user_id == "user-1" + ] + 
mock_session.execute.return_value = mock_result + + result = storage.get_all_sessions(user_id="user-1") + assert len(result) == 2 + assert all(s.user_id == "user-1" for s in result) + + # Test filtering by agent_id + mock_session.execute.reset_mock() + mock_result.fetchall.return_value = [ + MagicMock(_mapping=session.to_dict()) for session in sessions if session.agent_id == "agent-1" + ] + mock_session.execute.return_value = mock_result + + result = storage.get_all_sessions(entity_id="agent-1") + assert len(result) == 2 + assert all(s.agent_id == "agent-1" for s in result) + + +def test_get_all_session_ids(agent_storage): + """Test retrieving all session IDs.""" + storage, mock_session = agent_storage + + # Mock the fetchall result + mock_result = MagicMock() + mock_result.fetchall.return_value = [("session-1",), ("session-2",), ("session-3",)] + mock_session.execute.return_value = mock_result + + # Test get_all_session_ids + result = storage.get_all_session_ids() + assert result == ["session-1", "session-2", "session-3"] + + +def test_table_exists(agent_storage): + """Test the table_exists method.""" + storage, mock_session = agent_storage + + # Test when table exists + mock_scalar = MagicMock(return_value=1) + mock_session.execute.return_value.scalar = mock_scalar + + # Reset the mocked table_exists + storage.table_exists = PostgresStorage.table_exists.__get__(storage) + + assert storage.table_exists() is True + + # Test when table doesn't exist + mock_scalar = MagicMock(return_value=None) + mock_session.execute.return_value.scalar = mock_scalar + + assert storage.table_exists() is False + + +def test_create_table(agent_storage): + """Test table creation.""" + storage, mock_session = agent_storage + + # Reset the mocked table_exists + storage.table_exists = MagicMock(return_value=False) + + # Mock the create method + with patch.object(storage.table, "create"): + storage.create() + mock_session.execute.assert_called() # For schema creation + # The actual table creation is more complex with indexes, so we don't verify all details + + +def test_drop_table(agent_storage): + """Test dropping a table.""" + storage, mock_session = agent_storage + + # Mock table_exists to return True + storage.table_exists = MagicMock(return_value=True) + + # Mock the drop method + with patch.object(storage.table, "drop") as mock_drop: + storage.drop() + mock_drop.assert_called_once_with(storage.db_engine, checkfirst=True) + + +def test_mode_switching(): + """Test switching between agent and workflow modes.""" + with patch("agno.storage.postgres.scoped_session"): + with patch("agno.storage.postgres.inspect"): + with patch("agno.storage.postgres.create_engine"): + # Create storage in agent mode + storage = PostgresStorage(table_name="test_table", db_url="postgresql://user:pass@localhost/db") + assert storage.mode == "agent" + + # Switch to workflow mode + with patch.object(storage, "get_table") as mock_get_table: + storage.mode = "workflow" + assert storage.mode == "workflow" + mock_get_table.assert_called_once() diff --git a/libs/agno/tests/unit/storage/test_singlestore_storage.py b/libs/agno/tests/unit/storage/test_singlestore_storage.py new file mode 100644 index 0000000000..0da15a58df --- /dev/null +++ b/libs/agno/tests/unit/storage/test_singlestore_storage.py @@ -0,0 +1,327 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from agno.storage.session.agent import AgentSession +from agno.storage.session.workflow import WorkflowSession +from agno.storage.singlestore import SingleStoreStorage + 
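The mode-switching tests for Postgres (and SingleStore below) pivot on `mode` being a property whose setter rebuilds the table, which is why `get_table` is patched and asserted. A generic sketch of that pattern, with `get_table` standing in for the real schema builder:

```python
class Storage:
    def __init__(self, table_name: str, mode: str = "agent"):
        self.table_name = table_name
        self._mode = mode
        self.table = self.get_table()

    @property
    def mode(self) -> str:
        return self._mode

    @mode.setter
    def mode(self, value: str) -> None:
        # Agent and workflow sessions have different columns,
        # so changing mode forces a table rebuild.
        self._mode = value
        self.table = self.get_table()

    def get_table(self):
        return f"{self.table_name} ({self._mode} schema)"

storage = Storage("test_table")
assert storage.mode == "agent"
storage.mode = "workflow"
assert storage.table == "test_table (workflow schema)"
```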
+ +@pytest.fixture +def mock_engine(): + """Create a mock SQLAlchemy engine.""" + engine = MagicMock() + return engine + + +@pytest.fixture +def mock_session(): + """Create a mock SQLAlchemy session.""" + session_factory = MagicMock() + session_instance = MagicMock() + + # Set up the context manager behavior + context_manager = MagicMock() + context_manager.__enter__ = MagicMock(return_value=session_instance) + context_manager.__exit__ = MagicMock(return_value=None) + + # Make the session factory's begin() return the context manager + session_factory.begin = MagicMock(return_value=context_manager) + + return session_factory, session_instance + + +@pytest.fixture +def agent_storage(mock_engine, mock_session): + """Create a SingleStoreStorage instance for agent mode with mocked components.""" + session_factory, session_instance = mock_session + with patch("agno.storage.singlestore.sessionmaker", return_value=session_factory): + with patch("agno.storage.singlestore.inspect", return_value=MagicMock()): + storage = SingleStoreStorage(table_name="agent_sessions", schema="ai", db_engine=mock_engine, mode="agent") + # Mock table_exists to return True + storage.table_exists = MagicMock(return_value=True) + return storage, session_instance + + +@pytest.fixture +def workflow_storage(mock_engine, mock_session): + """Create a SingleStoreStorage instance for workflow mode with mocked components.""" + session_factory, session_instance = mock_session + with patch("agno.storage.singlestore.sessionmaker", return_value=session_factory): + with patch("agno.storage.singlestore.inspect", return_value=MagicMock()): + storage = SingleStoreStorage( + table_name="workflow_sessions", schema="ai", db_engine=mock_engine, mode="workflow" + ) + # Mock table_exists to return True + storage.table_exists = MagicMock(return_value=True) + return storage, session_instance + + +def test_initialization(): + """Test SingleStoreStorage initialization with different parameters.""" + # Test with db_url + with patch("agno.storage.singlestore.create_engine") as mock_create_engine: + with patch("agno.storage.singlestore.sessionmaker"): + with patch("agno.storage.singlestore.inspect"): + mock_engine = MagicMock() + mock_create_engine.return_value = mock_engine + + storage = SingleStoreStorage(table_name="test_table", db_url="mysql://user:pass@localhost/db") + + mock_create_engine.assert_called_once_with( + "mysql://user:pass@localhost/db", connect_args={"charset": "utf8mb4"} + ) + assert storage.table_name == "test_table" + assert storage.schema == "ai" # Default value + assert storage.mode == "agent" # Default value + + # Test with missing db_url and db_engine + with pytest.raises(ValueError, match="Must provide either db_url or db_engine"): + SingleStoreStorage(table_name="test_table") + + +def test_agent_storage_crud(agent_storage): + """Test CRUD operations for agent storage.""" + storage, mock_session = agent_storage + + # Create a test session + session = AgentSession( + session_id="test-session", + agent_id="test-agent", + user_id="test-user", + memory={"key": "value"}, + agent_data={"name": "Test Agent"}, + session_data={"state": "active"}, + extra_data={"custom": "data"}, + ) + + # Mock the read method to return None initially (for checking if exists) + # and then return the session after upsert + original_read = storage.read + storage.read = MagicMock(return_value=None) + + # Mock upsert to return the session directly + original_upsert = storage.upsert + storage.upsert = MagicMock(return_value=session) + + # Test upsert + 
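The `mock_session` fixture above wires up `__enter__`/`__exit__` by hand because the code under test enters sessions via `with session_factory.begin() as sess:`. A stripped-down sketch of that mocking pattern on its own:

```python
from unittest.mock import MagicMock

session_instance = MagicMock()
context_manager = MagicMock()
context_manager.__enter__ = MagicMock(return_value=session_instance)
context_manager.__exit__ = MagicMock(return_value=None)

session_factory = MagicMock()
session_factory.begin = MagicMock(return_value=context_manager)

# Equivalent of the code under test: with session_factory.begin() as sess: ...
with session_factory.begin() as sess:
    sess.execute("SELECT 1")

session_instance.execute.assert_called_once_with("SELECT 1")
```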
result = storage.upsert(session) + assert result == session + + # Restore original methods + storage.read = original_read + storage.upsert = original_upsert + + # Now test read with a direct mock + storage.read = MagicMock(return_value=session) + read_result = storage.read("test-session") + assert read_result == session + + # Test delete by mocking the delete_session method directly + storage.delete_session = MagicMock() + storage.delete_session("test-session") + storage.delete_session.assert_called_once_with("test-session") + + +def test_workflow_storage_crud(workflow_storage): + """Test CRUD operations for workflow storage.""" + storage, mock_session = workflow_storage + + # Create a test session + session = WorkflowSession( + session_id="test-session", + workflow_id="test-workflow", + user_id="test-user", + memory={"key": "value"}, + workflow_data={"name": "Test Workflow"}, + session_data={"state": "active"}, + extra_data={"custom": "data"}, + ) + + # Mock the read method to return None initially (for checking if exists) + # and then return the session after upsert + original_read = storage.read + storage.read = MagicMock(return_value=None) + + # Mock upsert to return the session directly + original_upsert = storage.upsert + storage.upsert = MagicMock(return_value=session) + + # Test upsert + result = storage.upsert(session) + assert result == session + + # Restore original methods + storage.read = original_read + storage.upsert = original_upsert + + # Now test read with a direct mock + storage.read = MagicMock(return_value=session) + read_result = storage.read("test-session") + assert read_result == session + + # Test delete by mocking the delete_session method directly + storage.delete_session = MagicMock() + storage.delete_session("test-session") + storage.delete_session.assert_called_once_with("test-session") + + +def test_get_all_sessions(agent_storage): + """Test retrieving all sessions.""" + storage, mock_session = agent_storage + + # Create mock sessions with proper _mapping attribute + mock_rows = [] + for i in range(4): + mock_row = MagicMock() + session_data = { + "session_id": f"session-{i}", + "agent_id": f"agent-{i % 2 + 1}", + "user_id": f"user-{i % 2 + 1}", + "memory": {}, + "agent_data": {}, + "session_data": {}, + "extra_data": {}, + "created_at": 1000000, + "updated_at": None, + } + mock_row._mapping = session_data + mock_row.session_id = session_data["session_id"] + mock_rows.append(mock_row) + + # Mock the execute result + mock_result = MagicMock() + mock_result.fetchall.return_value = mock_rows + mock_session.execute.return_value = mock_result + + # Test get_all_sessions without filters + result = storage.get_all_sessions() + assert len(result) == 4 + assert all(isinstance(s, AgentSession) for s in result) + + # Reset mock for user_id filter test + mock_session.reset_mock() + mock_rows_filtered = [row for row in mock_rows if row._mapping["user_id"] == "user-1"] + mock_result = MagicMock() + mock_result.fetchall.return_value = mock_rows_filtered + mock_session.execute.return_value = mock_result + + result = storage.get_all_sessions(user_id="user-1") + assert len(result) == 2 + assert all(s.user_id == "user-1" for s in result) + + # Reset mock for agent_id filter test + mock_session.reset_mock() + mock_rows_filtered = [row for row in mock_rows if row._mapping["agent_id"] == "agent-1"] + mock_result = MagicMock() + mock_result.fetchall.return_value = mock_rows_filtered + mock_session.execute.return_value = mock_result + + result = 
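These `get_all_sessions` tests fake SQLAlchemy 1.4+ result rows, where each row exposes its columns through a `_mapping` attribute. A tiny sketch of building such a fake row and consuming it the way the storage layer presumably does:

```python
from unittest.mock import MagicMock

row = MagicMock()
row._mapping = {"session_id": "session-1", "user_id": "user-1"}

# Code under test typically converts rows with dict(row._mapping)
record = dict(row._mapping)
assert record == {"session_id": "session-1", "user_id": "user-1"}
```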
storage.get_all_sessions(entity_id="agent-1") + assert len(result) == 2 + assert all(s.agent_id == "agent-1" for s in result) + + +def test_get_all_session_ids(agent_storage): + """Test retrieving all session IDs.""" + storage, mock_session = agent_storage + + # Create mock rows with session_id attribute + mock_rows = [] + for i in range(3): + mock_row = MagicMock() + mock_row.session_id = f"session-{i + 1}" + mock_rows.append(mock_row) + + # Mock the execute result + mock_result = MagicMock() + mock_result.fetchall.return_value = mock_rows + mock_session.execute.return_value = mock_result + + # Test get_all_session_ids without filters + result = storage.get_all_session_ids() + assert result == ["session-1", "session-2", "session-3"] + assert mock_session.execute.called + + # Reset mock for user_id filter test + mock_session.reset_mock() + mock_rows_filtered = mock_rows[:2] # Only return first two sessions + mock_result = MagicMock() + mock_result.fetchall.return_value = mock_rows_filtered + mock_session.execute.return_value = mock_result + + result = storage.get_all_session_ids(user_id="test-user") + assert result == ["session-1", "session-2"] + assert mock_session.execute.called + + # Reset mock for entity_id filter test + mock_session.reset_mock() + mock_rows_filtered = mock_rows[2:] # Only return last session + mock_result = MagicMock() + mock_result.fetchall.return_value = mock_rows_filtered + mock_session.execute.return_value = mock_result + + result = storage.get_all_session_ids(entity_id="test-agent") + assert result == ["session-3"] + assert mock_session.execute.called + + +def test_table_exists(agent_storage): + """Test the table_exists method.""" + storage, _ = agent_storage + + # Test when table exists + with patch("agno.storage.singlestore.inspect") as mock_inspect: + mock_inspect.return_value.has_table.return_value = True + + # Reset the mocked table_exists + storage.table_exists = SingleStoreStorage.table_exists.__get__(storage) + + assert storage.table_exists() is True + + # Test when table doesn't exist + mock_inspect.return_value.has_table.return_value = False + + assert storage.table_exists() is False + + +def test_create_table(agent_storage): + """Test table creation.""" + storage, _ = agent_storage + + # Reset the mocked table_exists + storage.table_exists = MagicMock(return_value=False) + + # Mock the create method + with patch.object(storage.table, "create") as mock_create: + storage.create() + mock_create.assert_called_once_with(storage.db_engine) + + +def test_drop_table(agent_storage): + """Test dropping a table.""" + storage, _ = agent_storage + + # Mock table_exists to return True + storage.table_exists = MagicMock(return_value=True) + + # Mock the drop method + with patch.object(storage.table, "drop") as mock_drop: + storage.drop() + mock_drop.assert_called_once_with(storage.db_engine) + + +def test_mode_switching(): + """Test switching between agent and workflow modes.""" + with patch("agno.storage.singlestore.sessionmaker"): + with patch("agno.storage.singlestore.inspect"): + with patch("agno.storage.singlestore.create_engine"): + # Create storage in agent mode + storage = SingleStoreStorage(table_name="test_table", db_url="mysql://user:pass@localhost/db") + assert storage.mode == "agent" + + # Switch to workflow mode + with patch.object(storage, "get_table") as mock_get_table: + storage.mode = "workflow" + assert storage.mode == "workflow" + mock_get_table.assert_called_once() diff --git a/libs/agno/tests/unit/storage/test_sqlite_storage.py 
b/libs/agno/tests/unit/storage/test_sqlite_storage.py new file mode 100644 index 0000000000..687491184b --- /dev/null +++ b/libs/agno/tests/unit/storage/test_sqlite_storage.py @@ -0,0 +1,195 @@ +import os +import tempfile +from pathlib import Path +from typing import Generator + +import pytest + +from agno.storage.session.agent import AgentSession +from agno.storage.session.workflow import WorkflowSession +from agno.storage.sqlite import SqliteStorage + + +@pytest.fixture +def temp_db_path() -> Generator[Path, None, None]: + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: + db_path = Path(f.name) + yield db_path + if db_path.exists(): + os.unlink(db_path) + + +@pytest.fixture +def agent_storage(temp_db_path: Path) -> SqliteStorage: + return SqliteStorage(table_name="agent_sessions", db_file=str(temp_db_path), mode="agent") + + +@pytest.fixture +def workflow_storage(temp_db_path: Path) -> SqliteStorage: + return SqliteStorage(table_name="workflow_sessions", db_file=str(temp_db_path), mode="workflow") + + +def test_agent_storage_crud(agent_storage: SqliteStorage): + # Test create + agent_storage.create() + assert agent_storage.table_exists() + + # Test upsert + session = AgentSession( + session_id="test-session", + agent_id="test-agent", + user_id="test-user", + memory={"key": "value"}, + agent_data={"name": "Test Agent"}, + session_data={"state": "active"}, + extra_data={"custom": "data"}, + ) + + saved_session = agent_storage.upsert(session) + assert saved_session is not None + assert saved_session.session_id == session.session_id + + # Test read + read_session = agent_storage.read("test-session") + assert read_session is not None + assert read_session.session_id == session.session_id + assert read_session.agent_id == session.agent_id + assert read_session.memory == session.memory + + # Test get all sessions + all_sessions = agent_storage.get_all_sessions() + assert len(all_sessions) == 1 + assert all_sessions[0].session_id == session.session_id + + # Test delete + agent_storage.delete_session("test-session") + assert agent_storage.read("test-session") is None + + # Test drop + agent_storage.drop() + assert not agent_storage.table_exists() + + +def test_workflow_storage_crud(workflow_storage: SqliteStorage): + # Test create + workflow_storage.create() + assert workflow_storage.table_exists() + + # Test upsert + session = WorkflowSession( + session_id="test-session", + workflow_id="test-workflow", + user_id="test-user", + memory={"key": "value"}, + workflow_data={"name": "Test Workflow"}, + session_data={"state": "active"}, + extra_data={"custom": "data"}, + ) + + saved_session = workflow_storage.upsert(session) + assert saved_session is not None + assert saved_session.session_id == session.session_id + + # Test read + read_session = workflow_storage.read("test-session") + assert read_session is not None + assert read_session.session_id == session.session_id + assert read_session.workflow_id == session.workflow_id + assert read_session.memory == session.memory + + # Test get all sessions + all_sessions = workflow_storage.get_all_sessions() + assert len(all_sessions) == 1 + assert all_sessions[0].session_id == session.session_id + + # Test delete + workflow_storage.delete_session("test-session") + assert workflow_storage.read("test-session") is None + + # Test drop + workflow_storage.drop() + assert not workflow_storage.table_exists() + + +def test_storage_filtering(agent_storage: SqliteStorage): + # Create test sessions with different combinations + sessions = [ + 
AgentSession( + session_id=f"session-{i}", + agent_id=f"agent-{i // 2 + 1}", # agent-1, agent-1, agent-2, agent-2 + user_id=f"user-{i % 3 + 1}", # user-1, user-2, user-3, user-1 + memory={"test": f"memory-{i}"}, + agent_data={"name": f"Agent {i}"}, + session_data={"state": "active"}, + ) + for i in range(4) + ] + + for session in sessions: + agent_storage.upsert(session) + + # Test filtering by user_id + for user_id in ["user-1", "user-2", "user-3"]: + user_sessions = agent_storage.get_all_sessions(user_id=user_id) + assert all(s.user_id == user_id for s in user_sessions) + + # Test filtering by agent_id + for agent_id in ["agent-1", "agent-2"]: + agent_sessions = agent_storage.get_all_sessions(entity_id=agent_id) + assert all(s.agent_id == agent_id for s in agent_sessions) + assert len(agent_sessions) == 2 # Each agent has 2 sessions + + # Test combined filtering + filtered_sessions = agent_storage.get_all_sessions(user_id="user-1", entity_id="agent-1") + assert len(filtered_sessions) == 1 + assert filtered_sessions[0].user_id == "user-1" + assert filtered_sessions[0].agent_id == "agent-1" + + # Test filtering with non-existent IDs + empty_sessions = agent_storage.get_all_sessions(user_id="non-existent") + assert len(empty_sessions) == 0 + + empty_sessions = agent_storage.get_all_sessions(entity_id="non-existent") + assert len(empty_sessions) == 0 + + +def test_workflow_storage_filtering(workflow_storage: SqliteStorage): + # Create test sessions with different combinations + sessions = [ + WorkflowSession( + session_id=f"session-{i}", + workflow_id=f"workflow-{i // 2 + 1}", # workflow-1, workflow-1, workflow-2, workflow-2 + user_id=f"user-{i % 3 + 1}", # user-1, user-2, user-3, user-1 + memory={"test": f"memory-{i}"}, + workflow_data={"name": f"Workflow {i}"}, + session_data={"state": "active"}, + ) + for i in range(4) + ] + + for session in sessions: + workflow_storage.upsert(session) + + # Test filtering by user_id + for user_id in ["user-1", "user-2", "user-3"]: + user_sessions = workflow_storage.get_all_sessions(user_id=user_id) + assert all(s.user_id == user_id for s in user_sessions) + + # Test filtering by workflow_id + for workflow_id in ["workflow-1", "workflow-2"]: + workflow_sessions = workflow_storage.get_all_sessions(entity_id=workflow_id) + assert all(s.workflow_id == workflow_id for s in workflow_sessions) + assert len(workflow_sessions) == 2 # Each workflow has 2 sessions + + # Test combined filtering + filtered_sessions = workflow_storage.get_all_sessions(user_id="user-1", entity_id="workflow-1") + assert len(filtered_sessions) == 1 + assert filtered_sessions[0].user_id == "user-1" + assert filtered_sessions[0].workflow_id == "workflow-1" + + # Test filtering with non-existent IDs + empty_sessions = workflow_storage.get_all_sessions(user_id="non-existent") + assert len(empty_sessions) == 0 + + empty_sessions = workflow_storage.get_all_sessions(entity_id="non-existent") + assert len(empty_sessions) == 0 diff --git a/libs/agno/tests/unit/storage/test_yaml_storage.py b/libs/agno/tests/unit/storage/test_yaml_storage.py new file mode 100644 index 0000000000..129db3d2f8 --- /dev/null +++ b/libs/agno/tests/unit/storage/test_yaml_storage.py @@ -0,0 +1,176 @@ +import tempfile +from pathlib import Path +from typing import Generator + +import pytest + +from agno.storage.session.agent import AgentSession +from agno.storage.session.workflow import WorkflowSession +from agno.storage.yaml import YamlStorage + + +@pytest.fixture +def temp_dir() -> Generator[Path, None, None]: + with 
tempfile.TemporaryDirectory() as temp_dir: + yield Path(temp_dir) + + +@pytest.fixture +def agent_storage(temp_dir: Path) -> YamlStorage: + return YamlStorage(dir_path=temp_dir, mode="agent") + + +@pytest.fixture +def workflow_storage(temp_dir: Path) -> YamlStorage: + return YamlStorage(dir_path=temp_dir, mode="workflow") + + +def test_agent_storage_crud(agent_storage: YamlStorage, temp_dir: Path): + # Test create + agent_storage.create() + assert temp_dir.exists() + + # Test upsert + session = AgentSession( + session_id="test-session", + agent_id="test-agent", + user_id="test-user", + memory={"key": "value"}, + agent_data={"name": "Test Agent"}, + session_data={"state": "active"}, + extra_data={"custom": "data"}, + ) + + saved_session = agent_storage.upsert(session) + assert saved_session is not None + assert saved_session.session_id == session.session_id + assert (temp_dir / "test-session.yaml").exists() + + # Test read + read_session = agent_storage.read("test-session") + assert read_session is not None + assert read_session.session_id == session.session_id + assert read_session.agent_id == session.agent_id + assert read_session.memory == session.memory + + # Test get all sessions + all_sessions = agent_storage.get_all_sessions() + assert len(all_sessions) == 1 + assert all_sessions[0].session_id == session.session_id + + # Test delete + agent_storage.delete_session("test-session") + assert agent_storage.read("test-session") is None + assert not (temp_dir / "test-session.yaml").exists() + + +def test_workflow_storage_crud(workflow_storage: YamlStorage, temp_dir: Path): + # Test create + workflow_storage.create() + assert temp_dir.exists() + + # Test upsert + session = WorkflowSession( + session_id="test-session", + workflow_id="test-workflow", + user_id="test-user", + memory={"key": "value"}, + workflow_data={"name": "Test Workflow"}, + session_data={"state": "active"}, + extra_data={"custom": "data"}, + ) + + saved_session = workflow_storage.upsert(session) + assert saved_session is not None + assert saved_session.session_id == session.session_id + assert (temp_dir / "test-session.yaml").exists() + + # Test read + read_session = workflow_storage.read("test-session") + assert read_session is not None + assert read_session.session_id == session.session_id + assert read_session.workflow_id == session.workflow_id + assert read_session.memory == session.memory + + # Test get all sessions + all_sessions = workflow_storage.get_all_sessions() + assert len(all_sessions) == 1 + assert all_sessions[0].session_id == session.session_id + + # Test delete + workflow_storage.delete_session("test-session") + assert workflow_storage.read("test-session") is None + assert not (temp_dir / "test-session.yaml").exists() + + +def test_storage_filtering(agent_storage: YamlStorage): + # Create test sessions + sessions = [ + AgentSession( + session_id=f"session-{i}", + agent_id="agent-1" if i < 2 else "agent-2", + user_id="user-1" if i % 2 == 0 else "user-2", + ) + for i in range(4) + ] + + for session in sessions: + agent_storage.upsert(session) + + # Test filtering by user_id + user1_sessions = agent_storage.get_all_sessions(user_id="user-1") + assert len(user1_sessions) == 2 + assert all(s.user_id == "user-1" for s in user1_sessions) + + # Test filtering by agent_id + agent1_sessions = agent_storage.get_all_sessions(entity_id="agent-1") + assert len(agent1_sessions) == 2 + assert all(s.agent_id == "agent-1" for s in agent1_sessions) + + # Test combined filtering + filtered_sessions = 
agent_storage.get_all_sessions(user_id="user-1", entity_id="agent-1") + assert len(filtered_sessions) == 1 + assert filtered_sessions[0].user_id == "user-1" + assert filtered_sessions[0].agent_id == "agent-1" + + +def test_workflow_storage_filtering(workflow_storage: YamlStorage): + # Create test sessions + sessions = [ + WorkflowSession( + session_id=f"session-{i}", + workflow_id="workflow-1" if i < 2 else "workflow-2", + user_id="user-1" if i % 2 == 0 else "user-2", + memory={"key": f"value-{i}"}, + workflow_data={"name": f"Test Workflow {i}"}, + session_data={"state": "active"}, + extra_data={"custom": f"data-{i}"}, + ) + for i in range(4) + ] + + for session in sessions: + workflow_storage.upsert(session) + + # Test filtering by user_id + user1_sessions = workflow_storage.get_all_sessions(user_id="user-1") + assert len(user1_sessions) == 2 + assert all(s.user_id == "user-1" for s in user1_sessions) + + # Test filtering by workflow_id + workflow1_sessions = workflow_storage.get_all_sessions(entity_id="workflow-1") + assert len(workflow1_sessions) == 2 + assert all(s.workflow_id == "workflow-1" for s in workflow1_sessions) + + # Test combined filtering + filtered_sessions = workflow_storage.get_all_sessions(user_id="user-1", entity_id="workflow-1") + assert len(filtered_sessions) == 1 + assert filtered_sessions[0].user_id == "user-1" + assert filtered_sessions[0].workflow_id == "workflow-1" + + # Test filtering with non-existent IDs + empty_sessions = workflow_storage.get_all_sessions(user_id="non-existent") + assert len(empty_sessions) == 0 + + empty_sessions = workflow_storage.get_all_sessions(entity_id="non-existent") + assert len(empty_sessions) == 0 diff --git a/libs/agno/tests/unit/tools/test_calculator.py b/libs/agno/tests/unit/tools/test_calculator.py new file mode 100644 index 0000000000..0d76631bab --- /dev/null +++ b/libs/agno/tests/unit/tools/test_calculator.py @@ -0,0 +1,293 @@ +"""Unit tests for CalculatorTools class.""" + +import json +from unittest.mock import patch + +import pytest + +from agno.tools.calculator import CalculatorTools + + +@pytest.fixture +def calculator_tools(): + """Create a CalculatorTools instance with all operations enabled.""" + return CalculatorTools(enable_all=True) + + +@pytest.fixture +def basic_calculator_tools(): + """Create a CalculatorTools instance with only basic operations.""" + return CalculatorTools() + + +def test_initialization_with_selective_operations(): + """Test initialization with only selected operations.""" + # Only enable specific operations + tools = CalculatorTools( + add=True, + subtract=True, + multiply=False, + divide=False, + exponentiate=True, + factorial=False, + is_prime=True, + square_root=False, + ) + + # Check which functions are registered + function_names = [func.name for func in tools.functions.values()] + + assert "add" in function_names + assert "subtract" in function_names + assert "multiply" not in function_names + assert "divide" not in function_names + assert "exponentiate" in function_names + assert "factorial" not in function_names + assert "is_prime" in function_names + assert "square_root" not in function_names + + +def test_initialization_with_all_operations(): + """Test initialization with all operations enabled.""" + tools = CalculatorTools(enable_all=True) + + function_names = [func.name for func in tools.functions.values()] + + assert "add" in function_names + assert "subtract" in function_names + assert "multiply" in function_names + assert "divide" in function_names + assert "exponentiate" in 
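The selective-registration behavior asserted in the calculator tests is the usual toolkit pattern: constructor flags decide which bound methods end up in `self.functions`. A generic sketch of how that gating works (not agno's actual Toolkit class):

```python
import json
import math

class MiniCalculator:
    def __init__(self, add: bool = True, multiply: bool = True,
                 factorial: bool = False, enable_all: bool = False):
        self.functions: dict = {}
        if add or enable_all:
            self.functions["add"] = self.add
        if multiply or enable_all:
            self.functions["multiply"] = self.multiply
        if factorial or enable_all:
            self.functions["factorial"] = self.factorial

    def add(self, a: float, b: float) -> str:
        return json.dumps({"operation": "addition", "result": a + b})

    def multiply(self, a: float, b: float) -> str:
        return json.dumps({"operation": "multiplication", "result": a * b})

    def factorial(self, n: int) -> str:
        if n < 0:
            return json.dumps({"error": "Factorial of a negative number is undefined"})
        return json.dumps({"operation": "factorial", "result": math.factorial(n)})

tools = MiniCalculator()  # basic: factorial not registered
assert "factorial" not in tools.functions
tools = MiniCalculator(enable_all=True)
assert "factorial" in tools.functions
```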
function_names + assert "factorial" in function_names + assert "is_prime" in function_names + assert "square_root" in function_names + + +def test_add_operation(calculator_tools): + """Test addition operation.""" + result = calculator_tools.add(5, 3) + result_data = json.loads(result) + + assert result_data["operation"] == "addition" + assert result_data["result"] == 8 + + # Test with negative numbers + result = calculator_tools.add(-5, 3) + result_data = json.loads(result) + assert result_data["result"] == -2 + + # Test with floating point numbers + result = calculator_tools.add(5.5, 3.2) + result_data = json.loads(result) + assert result_data["result"] == 8.7 + + +def test_subtract_operation(calculator_tools): + """Test subtraction operation.""" + result = calculator_tools.subtract(5, 3) + result_data = json.loads(result) + + assert result_data["operation"] == "subtraction" + assert result_data["result"] == 2 + + # Test with negative numbers + result = calculator_tools.subtract(-5, 3) + result_data = json.loads(result) + assert result_data["result"] == -8 + + # Test with floating point numbers + result = calculator_tools.subtract(5.5, 3.2) + result_data = json.loads(result) + assert result_data["result"] == 2.3 + + +def test_multiply_operation(calculator_tools): + """Test multiplication operation.""" + result = calculator_tools.multiply(5, 3) + result_data = json.loads(result) + + assert result_data["operation"] == "multiplication" + assert result_data["result"] == 15 + + # Test with negative numbers + result = calculator_tools.multiply(-5, 3) + result_data = json.loads(result) + assert result_data["result"] == -15 + + # Test with floating point numbers + result = calculator_tools.multiply(5.5, 3.2) + result_data = json.loads(result) + assert result_data["result"] == 17.6 + + +def test_divide_operation(calculator_tools): + """Test division operation.""" + result = calculator_tools.divide(6, 3) + result_data = json.loads(result) + + assert result_data["operation"] == "division" + assert result_data["result"] == 2 + + # Test with floating point result + result = calculator_tools.divide(5, 2) + result_data = json.loads(result) + assert result_data["result"] == 2.5 + + # Test division by zero + result = calculator_tools.divide(5, 0) + result_data = json.loads(result) + assert "error" in result_data + assert "Division by zero is undefined" in result_data["error"] + + +def test_exponentiate_operation(calculator_tools): + """Test exponentiation operation.""" + result = calculator_tools.exponentiate(2, 3) + result_data = json.loads(result) + + assert result_data["operation"] == "exponentiation" + assert result_data["result"] == 8 + + # Test with negative exponent + result = calculator_tools.exponentiate(2, -2) + result_data = json.loads(result) + assert result_data["result"] == 0.25 + + # Test with floating point numbers + result = calculator_tools.exponentiate(2.5, 2) + result_data = json.loads(result) + assert result_data["result"] == 6.25 + + +def test_factorial_operation(calculator_tools): + """Test factorial operation.""" + result = calculator_tools.factorial(5) + result_data = json.loads(result) + + assert result_data["operation"] == "factorial" + assert result_data["result"] == 120 + + # Test with zero + result = calculator_tools.factorial(0) + result_data = json.loads(result) + assert result_data["result"] == 1 + + # Test with negative number + result = calculator_tools.factorial(-1) + result_data = json.loads(result) + assert "error" in result_data + assert "Factorial of a negative 
number is undefined" in result_data["error"] + + +def test_is_prime_operation(calculator_tools): + """Test prime number checking operation.""" + # Test with prime number + result = calculator_tools.is_prime(7) + result_data = json.loads(result) + + assert result_data["operation"] == "prime_check" + assert result_data["result"] is True + + # Test with non-prime number + result = calculator_tools.is_prime(4) + result_data = json.loads(result) + assert result_data["result"] is False + + # Test with 1 (not prime by definition) + result = calculator_tools.is_prime(1) + result_data = json.loads(result) + assert result_data["result"] is False + + # Test with 0 (not prime) + result = calculator_tools.is_prime(0) + result_data = json.loads(result) + assert result_data["result"] is False + + # Test with negative number (not prime) + result = calculator_tools.is_prime(-5) + result_data = json.loads(result) + assert result_data["result"] is False + + +def test_square_root_operation(calculator_tools): + """Test square root operation.""" + result = calculator_tools.square_root(9) + result_data = json.loads(result) + + assert result_data["operation"] == "square_root" + assert result_data["result"] == 3 + + # Test with non-perfect square + result = calculator_tools.square_root(2) + result_data = json.loads(result) + assert result_data["result"] == pytest.approx(1.4142, 0.0001) + + # Test with negative number + result = calculator_tools.square_root(-1) + result_data = json.loads(result) + assert "error" in result_data + assert "Square root of a negative number is undefined" in result_data["error"] + + +def test_basic_calculator_has_only_basic_operations(basic_calculator_tools): + """Test that basic calculator only has basic operations.""" + function_names = [func.name for func in basic_calculator_tools.functions.values()] + + # Basic operations should be included + assert "add" in function_names + assert "subtract" in function_names + assert "multiply" in function_names + assert "divide" in function_names + + # Advanced operations should not be included + assert "exponentiate" not in function_names + assert "factorial" not in function_names + assert "is_prime" not in function_names + assert "square_root" not in function_names + + +def test_logging(calculator_tools): + """Test that operations are properly logged.""" + with patch("agno.tools.calculator.logger.info") as mock_logger: + calculator_tools.add(5, 3) + mock_logger.assert_called_once_with("Adding 5 and 3 to get 8") + + mock_logger.reset_mock() + calculator_tools.multiply(4, 2) + mock_logger.assert_called_once_with("Multiplying 4 and 2 to get 8") + + +def test_error_logging(calculator_tools): + """Test that errors are properly logged.""" + with patch("agno.tools.calculator.logger.error") as mock_logger: + calculator_tools.divide(5, 0) + mock_logger.assert_called_once_with("Attempt to divide by zero") + + mock_logger.reset_mock() + calculator_tools.factorial(-1) + mock_logger.assert_called_once_with("Attempt to calculate factorial of a negative number") + + mock_logger.reset_mock() + calculator_tools.square_root(-4) + mock_logger.assert_called_once_with("Attempt to calculate square root of a negative number") + + +def test_large_numbers(calculator_tools): + """Test operations with large numbers.""" + # Test factorial with large number + result = calculator_tools.factorial(20) + result_data = json.loads(result) + assert result_data["result"] == 2432902008176640000 + + # Test exponentiation with large numbers + result = calculator_tools.exponentiate(2, 
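Every operation in these tests returns a JSON string carrying either an `operation`/`result` pair or an `error` key. If you consume CalculatorTools output, a defensive parse along these lines is reasonable; `parse_result` is a made-up helper, not part of the library:

```python
import json
from typing import Union

def parse_result(payload: str) -> Union[int, float, bool]:
    data = json.loads(payload)
    if "error" in data:
        raise ValueError(data["error"])
    return data["result"]

assert parse_result('{"operation": "addition", "result": 8}') == 8
try:
    parse_result('{"error": "Division by zero is undefined"}')
    raise AssertionError("expected ValueError")
except ValueError as exc:
    assert "Division by zero" in str(exc)
```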
30) + result_data = json.loads(result) + assert result_data["result"] == 1073741824 + + +def test_division_exception_handling(calculator_tools): + """Test handling of exceptions in division.""" + with patch("math.pow", side_effect=Exception("Test exception")): + result = calculator_tools.divide(1, 0) + result_data = json.loads(result) + assert "error" in result_data diff --git a/libs/agno/tests/unit/utils/__init__.py b/libs/agno/tests/unit/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/libs/agno/tests/unit/vectordb/__init__.py b/libs/agno/tests/unit/vectordb/__init__.py new file mode 100644 index 0000000000..e69de29bb2
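One closing note on the float assertions in these calculator tests: bare `==` comparisons such as `result == 8.7` pass only because those particular operands happen to round to exactly the literal's double. For values that are not so lucky, `pytest.approx` is the safer idiom, as the square-root test already demonstrates:

```python
import pytest

# 5.5 + 3.2 rounds to the double nearest 8.7, so == works there,
# but approx tolerates the general case:
assert 0.1 + 0.2 == pytest.approx(0.3)
assert 2 ** 0.5 == pytest.approx(1.4142, abs=1e-4)
```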