diff --git a/pyproject.toml b/pyproject.toml index 88c363db1..8edfda029 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ # Keep sorted!!! "aioboto3>=11.0.0", "authlib~=1.5", + "chromadb~=1.1.0", "click~=8.1", "colorama~=0.4.6", "datasets~=4.0", # workaround for uv's solver choosing different versions of datasets based on sys_platform @@ -37,6 +38,7 @@ dependencies = [ "openinference-semantic-conventions~=0.1.14", "openpyxl~=3.1", "optuna~=4.4.0", + "pandas~=2.0", "pip>=24.3.1", "pkce==1.0.3", "pkginfo~=1.12", @@ -48,6 +50,7 @@ dependencies = [ "rich~=13.9", "tabulate~=0.9", "uvicorn[standard]~=0.34", + "vanna~=0.7.9", "wikipedia~=1.4", ] requires-python = ">=3.11,<3.14" @@ -80,6 +83,7 @@ mcp = ["nvidia-nat-mcp"] mem0ai = ["nvidia-nat-mem0ai"] opentelemetry = ["nvidia-nat-opentelemetry"] phoenix = ["nvidia-nat-phoenix"] +postgres = ["psycopg2-binary~=2.9"] profiling = ["nvidia-nat-profiling"] # meta-package ragaai = ["nvidia-nat-ragaai"] mysql = ["nvidia-nat-mysql"] diff --git a/src/nat/retriever/register.py b/src/nat/retriever/register.py index f1fb81786..a3717cdd9 100644 --- a/src/nat/retriever/register.py +++ b/src/nat/retriever/register.py @@ -19,3 +19,4 @@ # Import any providers which need to be automatically registered here import nat.retriever.milvus.register import nat.retriever.nemo_retriever.register +import nat.retriever.sql_retriever.register diff --git a/src/nat/retriever/sql_retriever/__init__.py b/src/nat/retriever/sql_retriever/__init__.py new file mode 100644 index 000000000..bcf38be34 --- /dev/null +++ b/src/nat/retriever/sql_retriever/__init__.py @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SQL Retriever module for NeMo Agent Toolkit.""" + +from nat.retriever.sql_retriever.sql_retriever import SQLRetriever + +__all__ = ["SQLRetriever"] + diff --git a/src/nat/retriever/sql_retriever/register.py b/src/nat/retriever/sql_retriever/register.py new file mode 100644 index 000000000..03d22c544 --- /dev/null +++ b/src/nat/retriever/sql_retriever/register.py @@ -0,0 +1,123 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Registration module for SQL Retriever."""

import logging
from typing import Literal

from pydantic import BaseModel
from pydantic import Field

from nat.builder.builder import Builder
from nat.cli.register_workflow import register_retriever
from nat.data_models.retriever import RetrieverBaseConfig
from nat.retriever.sql_retriever.sql_retriever import SQLRetriever

logger = logging.getLogger(__name__)


class SQLRetrieverConfig(RetrieverBaseConfig, name="sql_retriever"):
    """
    Configuration for SQL Retriever.

    This retriever uses Vanna AI with NVIDIA NIM to convert natural language queries
    into SQL and retrieve data from SQL databases.

    Supported database types:
    - sqlite: SQLite databases (local file-based)
    - postgres/postgresql: PostgreSQL databases
    - sql: Generic SQL databases via SQLAlchemy (MySQL, SQL Server, Oracle, etc.)
    """

    llm_name: str = Field(description="Name of the LLM to use for SQL generation")
    embedding_name: str = Field(
        description="Name of the embedding model to use for Vanna training data"
    )
    vector_store_path: str = Field(
        description="Path to ChromaDB vector store for Vanna training data"
    )
    # May embed a password, so it is kept out of the model's repr (and thus
    # out of tracebacks and debug dumps that print the config object).
    db_connection_string: str = Field(
        repr=False,
        description=(
            "Database connection string. Format depends on db_type:\n"
            "- sqlite: Path to .db file (e.g., '/path/to/database.db')\n"
            "- postgres: Connection string (e.g., 'postgresql://user:pass@host:port/db')\n"
            "- sql: SQLAlchemy connection string (e.g., 'mysql+pymysql://user:pass@host/db')"
        ),
    )
    # Restricted to the values the retriever is constructed with, so a typo
    # is rejected when the config is parsed rather than at connection time.
    db_type: Literal["sqlite", "postgres", "postgresql", "sql"] = Field(
        default="sqlite",
        description="Type of database: 'sqlite', 'postgres', or 'sql' (generic SQL via SQLAlchemy)",
    )
    training_data_path: str | None = Field(
        default=None,
        description="Path to YAML file containing Vanna training data (DDL, documentation, question-SQL pairs)",
    )
    max_results: int = Field(
        default=100,
        description="Maximum number of results to return from SQL queries",
    )
    # Secret: excluded from repr for the same reason as the connection string.
    nvidia_api_key: str | None = Field(
        default=None,
        repr=False,
        description="NVIDIA API key (optional, defaults to NVIDIA_API_KEY environment variable)",
    )


@register_retriever(config_type=SQLRetrieverConfig)
async def create_sql_retriever(config: SQLRetrieverConfig, builder: Builder):
    """
    Create and register a SQL Retriever instance.

    Looks up the LLM and embedder configurations named in ``config`` through
    ``builder`` and passes them, together with the remaining config fields, to
    ``SQLRetriever``.

    Args:
        config: SQLRetrieverConfig containing all necessary parameters
        builder: Builder instance for accessing LLM and embedder configurations

    Returns:
        SQLRetriever: Configured SQL retriever instance

    Example YAML configuration:
        ```yaml
        retrievers:
          - name: sql_retriever
            type: sql_retriever
            llm_name: nim_llm
            embedding_name: nim_embeddings
            vector_store_path: ./vanna_vector_store
            db_connection_string: ./database.db
            db_type: sqlite
            training_data_path: ./training_data.yaml
            max_results: 100
        ```

    NOTE(review): confirm the YAML shape above against the toolkit's config
    schema; it is reproduced from the original docstring and not validated here.
    """
    # ``name`` is only visibly a *class* keyword on the config; whether
    # instances expose ``.name`` cannot be confirmed from this module, so fall
    # back to the class name instead of risking an AttributeError.
    config_label = getattr(config, "name", type(config).__name__)
    logger.info("Creating SQL Retriever with config: %s", config_label)

    # Get LLM and embedder configurations from builder
    llm_config = builder.get_llm_config(config.llm_name)
    embedder_config = builder.get_embedder_config(config.embedding_name)

    # Create SQL retriever instance
    retriever = SQLRetriever(
        llm_config=llm_config,
        embedder_config=embedder_config,
        vector_store_path=config.vector_store_path,
        db_connection_string=config.db_connection_string,
        db_type=config.db_type,
        training_data_path=config.training_data_path,
        nvidia_api_key=config.nvidia_api_key,
        max_results=config.max_results,
    )

    logger.info("SQL Retriever '%s' created successfully", config_label)
    # NOTE(review): confirm whether ``register_retriever`` expects this
    # coroutine to ``return`` the retriever or to ``yield`` it (async-generator
    # style). Left as ``return`` to preserve the existing behavior.
    return retriever
"""SQL Retriever implementation using Vanna for text-to-SQL generation."""

import json
import logging
import re
from typing import Any

import pandas as pd

from nat.retriever.interface import Retriever
from nat.retriever.models import Document
from nat.retriever.models import RetrieverError
from nat.retriever.models import RetrieverOutput
from nat.retriever.sql_retriever.vanna_manager import VannaManager

logger = logging.getLogger(__name__)


class SQLRetriever(Retriever):
    """
    SQL Retriever that converts natural language queries to SQL and executes them.

    This retriever uses Vanna AI with NVIDIA NIM for text-to-SQL generation.
    It supports multiple database types: SQLite, PostgreSQL, and generic SQL databases via SQLAlchemy.

    Example:
        >>> from nat.retriever.sql_retriever import SQLRetriever
        >>> retriever = SQLRetriever(
        ...     llm_config=llm_config,
        ...     embedder_config=embedder_config,
        ...     vector_store_path="/path/to/vector_store",
        ...     db_connection_string="/path/to/database.db",
        ...     db_type="sqlite",
        ...     training_data_path="/path/to/training_data.yaml"
        ... )
        >>> results = await retriever.search("What are the top 10 customers by revenue?")
    """

    # Matches the ``:password@`` part of ``scheme://user:password@host`` URLs.
    _CREDENTIALS_RE = re.compile(r"(://[^:/@\s]+):[^@\s]*@")

    def __init__(
        self,
        llm_config: Any,
        embedder_config: Any,
        vector_store_path: str,
        db_connection_string: str,
        db_type: str = "sqlite",
        training_data_path: str | None = None,
        nvidia_api_key: str | None = None,
        max_results: int = 100,
        **kwargs,
    ):
        """
        Initialize the SQL Retriever.

        Args:
            llm_config: LLM configuration object with model_name attribute
            embedder_config: Embedder configuration object with model_name attribute
            vector_store_path: Path to ChromaDB vector store for Vanna training data
            db_connection_string: Database connection string:
                - SQLite: Path to .db file (e.g., "/path/to/database.db")
                - PostgreSQL: Connection string (e.g., "postgresql://user:pass@host:port/db")
                - Generic SQL: SQLAlchemy connection string (e.g., "mysql+pymysql://user:pass@host/db")
            db_type: Type of database - 'sqlite', 'postgres', or 'sql' (default: 'sqlite')
            training_data_path: Path to YAML file containing training data for Vanna
            nvidia_api_key: NVIDIA API key (optional, defaults to NVIDIA_API_KEY env var)
            max_results: Maximum number of results to return (default: 100)
            **kwargs: Additional keyword arguments (accepted and ignored)
        """
        self.llm_config = llm_config
        self.embedder_config = embedder_config
        self.vector_store_path = vector_store_path
        self.db_connection_string = db_connection_string
        self.db_type = db_type
        self.training_data_path = training_data_path
        self.nvidia_api_key = nvidia_api_key
        self.max_results = max_results

        # Create VannaManager instance
        self.vanna_manager = VannaManager.create_with_config(
            vanna_llm_config=llm_config,
            vanna_embedder_config=embedder_config,
            vector_store_path=vector_store_path,
            db_connection_string=db_connection_string,
            db_type=db_type,
            training_data_path=training_data_path,  # type: ignore[arg-type]
            nvidia_api_key=nvidia_api_key,  # type: ignore[arg-type]
        )

        # Connection URLs may carry a password; never write it to the log.
        logger.info(
            "SQLRetriever initialized with %s database at %s",
            db_type,
            self._redact_connection_string(db_connection_string),
        )

    @staticmethod
    def _redact_connection_string(connection_string: str) -> str:
        """Return ``connection_string`` with any URL password replaced by ``***``."""
        return SQLRetriever._CREDENTIALS_RE.sub(r"\1:***@", connection_string)

    async def search(self, query: str, **kwargs) -> RetrieverOutput:
        """
        Retrieve data from SQL database by converting natural language query to SQL.

        Args:
            query: Natural language query to convert to SQL
            **kwargs: Additional search parameters:
                - top_k: Maximum number of rows to return (overrides max_results);
                  ``None``, zero or a negative value means "no limit"
                - return_sql: If True, include the generated SQL in metadata (default: True)

        Returns:
            RetrieverOutput: A single document containing:
                - page_content: JSON string of result rows
                - metadata: 'row_count', 'columns', 'query' and (optionally) 'sql'

        Raises:
            RetrieverError: If SQL generation or execution fails
        """
        # NOTE(review): the Vanna calls below are synchronous and run on the
        # event loop; consider offloading them to a worker thread if this
        # retriever is used under concurrent load.
        top_k = kwargs.get("top_k", self.max_results)
        return_sql = kwargs.get("return_sql", True)

        try:
            logger.info("SQLRetriever: Processing query: %s", query)

            # Get Vanna instance
            vn_instance = self.vanna_manager.get_instance()

            # Generate SQL from natural language query
            logger.debug("Generating SQL query...")
            sql = self.vanna_manager.generate_sql_safe(question=query)
            logger.info("Generated SQL: %s", sql)

            # Check if database is connected
            if not vn_instance.run_sql_is_set:
                raise RetrieverError(
                    f"Database is not connected. Cannot execute SQL: {sql}"
                )

            # Execute SQL query
            logger.debug("Executing SQL query...")
            df = vn_instance.run_sql(sql)

            if df is None:
                raise RetrieverError(f"SQL execution returned None for query: {sql}")

            if df.empty:
                logger.warning("No results found for query: %s", query)
                empty_metadata: dict[str, Any] = {
                    "row_count": 0,
                    "columns": [],
                    "query": query,
                }
                if return_sql:
                    empty_metadata["sql"] = sql
                return RetrieverOutput(
                    results=[
                        Document(
                            page_content=json.dumps([]),
                            metadata=empty_metadata,
                        )
                    ]
                )

            # Limit results. Guard against non-positive values: ``df.head(-k)``
            # would silently drop rows from the end instead of limiting.
            if top_k is not None and 0 < top_k < len(df):
                logger.debug("Limiting results to top %s rows", top_k)
                df = df.head(top_k)

            # Convert DataFrame to documents
            results = self._dataframe_to_documents(
                df=df,
                sql=sql if return_sql else None,
                query=query,
            )

            logger.info("SQLRetriever: Retrieved %d rows", len(df))
            return results

        except RetrieverError:
            # Already the documented exception type: log and propagate as is
            # rather than wrapping it in a second RetrieverError.
            logger.error("Error in SQLRetriever.search", exc_info=True)
            raise
        except Exception as e:
            logger.error(f"Error in SQLRetriever.search: {e}", exc_info=True)
            raise RetrieverError(f"Failed to retrieve data from SQL database: {e}") from e

    def _dataframe_to_documents(
        self, df: pd.DataFrame, sql: str | None = None, query: str | None = None
    ) -> RetrieverOutput:
        """
        Convert a pandas DataFrame to RetrieverOutput format.

        Args:
            df: Pandas DataFrame containing query results
            sql: Generated SQL query (optional)
            query: Original natural language query (optional)

        Returns:
            RetrieverOutput: Output holding one document whose page content is
            the JSON-encoded list of row records.
        """
        # Round-trip through pandas' JSON encoder so that values the stdlib
        # ``json`` module cannot serialize are converted first.
        results_json = df.to_json(orient="records")
        results_list = [] if results_json is None else json.loads(results_json)

        # Create metadata
        metadata: dict[str, Any] = {
            "row_count": len(df),
            "columns": df.columns.tolist(),
        }

        if sql is not None:
            metadata["sql"] = sql
        if query is not None:
            metadata["query"] = query

        # Create a single document containing all results
        # For better integration, we could also create one document per row
        # but that might be overwhelming for large result sets
        document = Document(
            page_content=json.dumps(results_list, indent=2),
            metadata=metadata,
        )

        return RetrieverOutput(results=[document])

    def get_stats(self) -> dict:
        """
        Get statistics about the SQLRetriever instance.

        Returns:
            dict: Statistics including VannaManager info. The connection
            string is reported with any URL password redacted.
        """
        return {
            "db_type": self.db_type,
            "db_connection": self._redact_connection_string(self.db_connection_string),
            "vector_store_path": self.vector_store_path,
            "max_results": self.max_results,
            "vanna_manager": self.vanna_manager.get_stats(),
        }
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""VannaManager - A simplified manager for Vanna instances.""" + +import hashlib +import logging +import os +import threading +from typing import Dict + +from nat.retriever.sql_retriever.vanna_util import NIMVanna +from nat.retriever.sql_retriever.vanna_util import NVIDIAEmbeddingFunction +from nat.retriever.sql_retriever.vanna_util import init_vanna + +logger = logging.getLogger(__name__) + + +class VannaManager: + """ + A simplified singleton manager for Vanna instances. + + Key features: + - Singleton pattern to ensure only one instance per configuration + - Thread-safe operations + - Simple instance management + - Support for multiple database types: SQLite, generic SQL, and PostgreSQL + """ + + _instances: Dict[str, "VannaManager"] = {} + _lock = threading.Lock() + + def __new__(cls, config_key: str): + """Ensure singleton pattern per configuration.""" + with cls._lock: + if config_key not in cls._instances: + logger.debug( + f"VannaManager: Creating new singleton instance for config: {config_key}" + ) + cls._instances[config_key] = super().__new__(cls) + cls._instances[config_key]._initialized = False + else: + logger.debug( + f"VannaManager: Returning existing singleton instance for config: {config_key}" + ) + return cls._instances[config_key] + + def __init__( + self, + config_key: str, + vanna_llm_config=None, + vanna_embedder_config=None, + vector_store_path: str | None = None, + db_connection_string: str | None = None, + db_type: str = "sqlite", + training_data_path: str | None = None, + nvidia_api_key: str | None = None, + ): + """ + 
Initialize the VannaManager and create Vanna instance immediately if all config is provided. + + Args: + config_key: Unique key for this configuration + vanna_llm_config: LLM configuration object + vanna_embedder_config: Embedder configuration object + vector_store_path: Path to ChromaDB vector store + db_connection_string: Database connection string (path for SQLite, connection string for others) + db_type: Type of database - 'sqlite', 'postgres', or 'sql' (generic SQL with SQLAlchemy) + training_data_path: Path to YAML training data file + nvidia_api_key: NVIDIA API key (optional, can use NVIDIA_API_KEY env var) + """ + if hasattr(self, "_initialized") and self._initialized: + return + + self.config_key = config_key + self.lock = threading.Lock() + + # Store configuration + self.vanna_llm_config = vanna_llm_config + self.vanna_embedder_config = vanna_embedder_config + self.vector_store_path = vector_store_path + self.db_connection_string = db_connection_string + self.db_type = db_type + self.training_data_path = training_data_path + self.nvidia_api_key = nvidia_api_key or os.getenv("NVIDIA_API_KEY") + + # Create and initialize Vanna instance immediately if all required config is provided + self.vanna_instance = None + if all( + [ + vanna_llm_config, + vanna_embedder_config, + vector_store_path, + db_connection_string, + ] + ): + logger.debug("VannaManager: Initializing with immediate Vanna instance creation") + self.vanna_instance = self._create_instance() + else: + if any( + [ + vanna_llm_config, + vanna_embedder_config, + vector_store_path, + db_connection_string, + ] + ): + logger.debug( + "VannaManager: Partial configuration provided, Vanna instance will be created later" + ) + else: + logger.debug( + "VannaManager: No configuration provided, Vanna instance will be created later" + ) + + self._initialized = True + logger.debug(f"VannaManager initialized for config: {config_key}") + + def get_instance( + self, + vanna_llm_config=None, + 
vanna_embedder_config=None, + vector_store_path: str | None = None, + db_connection_string: str | None = None, + db_type: str | None = None, + training_data_path: str | None = None, + nvidia_api_key: str | None = None, + ) -> NIMVanna: + """ + Get the Vanna instance. If not created during init, create it now with provided parameters. + """ + with self.lock: + if self.vanna_instance is None: + logger.debug("VannaManager: No instance created during init, creating now...") + + # Update configuration with provided parameters + self.vanna_llm_config = vanna_llm_config or self.vanna_llm_config + self.vanna_embedder_config = ( + vanna_embedder_config or self.vanna_embedder_config + ) + self.vector_store_path = vector_store_path or self.vector_store_path + self.db_connection_string = ( + db_connection_string or self.db_connection_string + ) + self.db_type = db_type or self.db_type + self.training_data_path = training_data_path or self.training_data_path + self.nvidia_api_key = nvidia_api_key or self.nvidia_api_key + + if all( + [ + self.vanna_llm_config, + self.vanna_embedder_config, + self.vector_store_path, + self.db_connection_string, + ] + ): + self.vanna_instance = self._create_instance() + else: + raise RuntimeError( + "VannaManager: Missing required configuration parameters" + ) + else: + logger.debug( + f"VannaManager: Returning pre-initialized Vanna instance (ID: {id(self.vanna_instance)})" + ) + + # Show vector store status for pre-initialized instances + try: + if self.vector_store_path and os.path.exists(self.vector_store_path): + list_of_folders = [ + d + for d in os.listdir(self.vector_store_path) + if os.path.isdir(os.path.join(self.vector_store_path, d)) + ] + logger.debug( + f"VannaManager: Vector store contains {len(list_of_folders)} collections/folders" + ) + if list_of_folders: + logger.debug( + f"VannaManager: Vector store folders: {list_of_folders}" + ) + else: + logger.debug("VannaManager: Vector store directory does not exist") + except Exception as 
e: + logger.warning(f"VannaManager: Could not check vector store status: {e}") + + return self.vanna_instance + + def _create_instance(self) -> NIMVanna: + """ + Create a new Vanna instance using the stored configuration. + """ + # Type guards - these should never be None at this point due to earlier checks + if not all( + [ + self.vanna_llm_config, + self.vanna_embedder_config, + self.vector_store_path, + self.db_connection_string, + ] + ): + raise RuntimeError( + "VannaManager: Cannot create instance without required configuration" + ) + + # Assertions to help type checker understand these are not None + assert self.vanna_llm_config is not None + assert self.vanna_embedder_config is not None + assert self.vector_store_path is not None + assert self.db_connection_string is not None + + logger.info(f"VannaManager: Creating instance for {self.config_key}") + logger.debug(f"VannaManager: Vector store path: {self.vector_store_path}") + logger.debug(f"VannaManager: Database connection: {self.db_connection_string}") + logger.debug(f"VannaManager: Database type: {self.db_type}") + logger.debug(f"VannaManager: Training data path: {self.training_data_path}") + + # Create instance + vn_instance = NIMVanna( + VectorConfig={ + "client": "persistent", + "path": self.vector_store_path, + "embedding_function": NVIDIAEmbeddingFunction( + api_key=self.nvidia_api_key, + model=self.vanna_embedder_config.model_name, + ), + }, + LLMConfig={ + "api_key": self.nvidia_api_key, + "model": self.vanna_llm_config.model_name, + }, + ) + + # Connect to database based on type + logger.debug(f"VannaManager: Connecting to {self.db_type} database...") + if self.db_type == "sqlite": + vn_instance.connect_to_sqlite(self.db_connection_string) + elif self.db_type == "postgres" or self.db_type == "postgresql": + self._connect_to_postgres(vn_instance, self.db_connection_string) + elif self.db_type == "sql": + self._connect_to_sql(vn_instance, self.db_connection_string) + else: + raise ValueError( + 
f"Unsupported database type: {self.db_type}. " + "Supported types: 'sqlite', 'postgres', 'sql'" + ) + + # Set configuration - allow LLM to see data for database introspection + vn_instance.allow_llm_to_see_data = True + logger.debug("VannaManager: Set allow_llm_to_see_data = True") + + # Initialize if needed (check if vector store is empty) + needs_init = self._needs_initialization() + if needs_init: + logger.info( + "VannaManager: Vector store needs initialization, starting training..." + ) + try: + init_vanna(vn_instance, self.training_data_path) + logger.info("VannaManager: Vector store initialization complete") + except Exception as e: + logger.error(f"VannaManager: Error during initialization: {e}") + raise + else: + logger.debug( + "VannaManager: Vector store already initialized, skipping training" + ) + + logger.info("VannaManager: Instance created successfully") + return vn_instance + + def _connect_to_postgres(self, vn_instance: NIMVanna, connection_string: str): + """ + Connect to a PostgreSQL database. + + Args: + vn_instance: The Vanna instance to connect + connection_string: PostgreSQL connection string in format: + postgresql://user:password@host:port/database + """ + try: + import psycopg2 + from psycopg2.pool import SimpleConnectionPool + + logger.info("Connecting to PostgreSQL database...") + + # Parse connection string if needed + if connection_string.startswith("postgresql://"): + # Use SQLAlchemy-style connection for Vanna + vn_instance.connect_to_postgres(url=connection_string) + else: + # Assume it's a psycopg2 connection string + vn_instance.connect_to_postgres(url=f"postgresql://{connection_string}") + + logger.info("Successfully connected to PostgreSQL database") + except ImportError: + logger.error( + "psycopg2 is required for PostgreSQL connections. 
" + "Install it with: pip install psycopg2-binary" + ) + raise + except Exception as e: + logger.error(f"Error connecting to PostgreSQL: {e}") + raise + + def _connect_to_sql(self, vn_instance: NIMVanna, connection_string: str): + """ + Connect to a generic SQL database using SQLAlchemy. + + Args: + vn_instance: The Vanna instance to connect + connection_string: SQLAlchemy-compatible connection string, e.g.: + - MySQL: mysql+pymysql://user:password@host:port/database + - PostgreSQL: postgresql://user:password@host:port/database + - SQL Server: mssql+pyodbc://user:password@host:port/database?driver=ODBC+Driver+17+for+SQL+Server + - Oracle: oracle+cx_oracle://user:password@host:port/?service_name=service + """ + try: + from sqlalchemy import create_engine + + logger.info("Connecting to SQL database via SQLAlchemy...") + + # Create SQLAlchemy engine + engine = create_engine(connection_string) + + # Connect Vanna to the database using the engine + vn_instance.connect_to_sqlalchemy(engine) + + logger.info("Successfully connected to SQL database") + except ImportError: + logger.error( + "SQLAlchemy is required for generic SQL connections. " + "Install it with: pip install sqlalchemy" + ) + raise + except Exception as e: + logger.error(f"Error connecting to SQL database: {e}") + raise + + def _needs_initialization(self) -> bool: + """ + Check if the vector store needs initialization by checking if it's empty. 
+ """ + logger.debug("VannaManager: Checking if vector store needs initialization...") + logger.debug(f"VannaManager: Vector store path: {self.vector_store_path}") + + # Type guard - vector_store_path should be set at this point + if self.vector_store_path is None: + logger.warning("VannaManager: Vector store path is None, assuming initialization needed") + return True + + try: + if not os.path.exists(self.vector_store_path): + logger.debug( + "VannaManager: Vector store directory does not exist -> needs initialization" + ) + return True + + # Check if there are any subdirectories (ChromaDB creates subdirectories when data is stored) + list_of_folders = [ + d + for d in os.listdir(self.vector_store_path) + if os.path.isdir(os.path.join(self.vector_store_path, d)) + ] + + logger.debug( + f"VannaManager: Found {len(list_of_folders)} folders in vector store" + ) + if list_of_folders: + logger.debug(f"VannaManager: Vector store folders: {list_of_folders}") + logger.debug( + "VannaManager: Vector store is populated -> skipping initialization" + ) + return False + else: + logger.debug("VannaManager: Vector store is empty -> needs initialization") + return True + + except Exception as e: + logger.warning(f"VannaManager: Could not check vector store status: {e}") + logger.warning("VannaManager: Defaulting to needs initialization = True") + return True + + def generate_sql_safe(self, question: str) -> str: + """ + Generate SQL with error handling. 
+ """ + with self.lock: + if self.vanna_instance is None: + raise RuntimeError("VannaManager: No instance available") + + try: + logger.debug(f"VannaManager: Generating SQL for question: {question}") + + # Generate SQL with allow_llm_to_see_data=True for database introspection + sql = self.vanna_instance.generate_sql( + question=question, allow_llm_to_see_data=True + ) + + # Validate SQL response + if not sql or sql.strip() == "": + raise ValueError("Empty SQL response") + + return sql + + except Exception as e: + logger.error(f"VannaManager: Error in SQL generation: {e}") + raise + + def force_reset(self): + """ + Force reset the instance (useful for cleanup). + """ + with self.lock: + if self.vanna_instance: + logger.debug(f"VannaManager: Resetting instance for {self.config_key}") + self.vanna_instance = None + + def get_stats(self) -> Dict: + """ + Get manager statistics. + """ + return { + "config_key": self.config_key, + "instance_id": id(self.vanna_instance) if self.vanna_instance else None, + "has_instance": self.vanna_instance is not None, + "db_type": self.db_type, + } + + @classmethod + def create_with_config( + cls, + vanna_llm_config, + vanna_embedder_config, + vector_store_path: str, + db_connection_string: str, + db_type: str = "sqlite", + training_data_path: str | None = None, + nvidia_api_key: str | None = None, + ): + """ + Class method to create a VannaManager with full configuration. + Uses create_config_key to ensure singleton behavior based on configuration. 
+ + Args: + vanna_llm_config: LLM configuration object + vanna_embedder_config: Embedder configuration object + vector_store_path: Path to ChromaDB vector store + db_connection_string: Database connection string + db_type: Type of database - 'sqlite', 'postgres', or 'sql' + training_data_path: Path to YAML training data file + nvidia_api_key: NVIDIA API key (optional) + """ + config_key = create_config_key( + vanna_llm_config, + vanna_embedder_config, + vector_store_path, + db_connection_string, + db_type, + ) + + # Create instance with just config_key (singleton pattern) + instance = cls(config_key) + + # If this is a new instance that hasn't been configured yet, set the configuration + if not hasattr(instance, "vanna_llm_config") or instance.vanna_llm_config is None: + instance.vanna_llm_config = vanna_llm_config + instance.vanna_embedder_config = vanna_embedder_config + instance.vector_store_path = vector_store_path + instance.db_connection_string = db_connection_string + instance.db_type = db_type + instance.training_data_path = training_data_path + instance.nvidia_api_key = nvidia_api_key + + # Create Vanna instance immediately if all config is available + if instance.vanna_instance is None: + logger.debug("VannaManager: Creating Vanna instance for existing singleton") + instance.vanna_instance = instance._create_instance() + + return instance + + +def create_config_key( + vanna_llm_config, + vanna_embedder_config, + vector_store_path: str, + db_connection_string: str, + db_type: str = "sqlite", +) -> str: + """ + Create a unique configuration key for the VannaManager singleton. 
+ + Args: + vanna_llm_config: LLM configuration object + vanna_embedder_config: Embedder configuration object + vector_store_path: Path to vector store + db_connection_string: Database connection string + db_type: Type of database + + Returns: + str: Unique configuration key + """ + config_str = f"{vanna_llm_config.model_name}_{vanna_embedder_config.model_name}_{vector_store_path}_{db_connection_string}_{db_type}" + return hashlib.md5(config_str.encode()).hexdigest()[:12] + diff --git a/src/nat/retriever/sql_retriever/vanna_util.py b/src/nat/retriever/sql_retriever/vanna_util.py new file mode 100644 index 000000000..33b51c0cb --- /dev/null +++ b/src/nat/retriever/sql_retriever/vanna_util.py @@ -0,0 +1,948 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Vanna utilities for SQL generation using NVIDIA NIM services."""

import hashlib
import logging
import os
import traceback
from datetime import datetime, timezone

from langchain_nvidia import NVIDIAEmbeddings
from vanna.base import VannaBase
from vanna.chromadb import ChromaDB_VectorStore

logger = logging.getLogger(__name__)


class NIMCustomLLM(VannaBase):
    """Custom LLM implementation for Vanna using NVIDIA NIM.

    The constructor builds a ``ChatNVIDIA`` client from ``config`` and
    ``submit_prompt`` forwards the message list to that client.
    """

    def __init__(self, config=None):
        """
        Build the ChatNVIDIA client.

        Args:
            config: Dictionary with the keys
                - api_key: NIM API key (required)
                - model: NIM model name (required)
                - temperature: Sampling temperature (optional, default 0.7)

        Raises:
            ValueError: If ``config`` is empty or lacks ``api_key`` / ``model``.
        """
        VannaBase.__init__(self, config=config)

        if not config:
            raise ValueError("config must be passed")

        # Default parameters - can be overridden using config
        self.temperature = config.get("temperature", 0.7)

        if "api_key" not in config:
            raise ValueError("config must contain a NIM api_key")

        if "model" not in config:
            raise ValueError("config must contain a NIM model")

        self.model = config["model"]

        # Local import: ChatNVIDIA is only needed once a client is actually built.
        from langchain_nvidia import ChatNVIDIA

        self.client = ChatNVIDIA(
            api_key=config["api_key"],
            model=self.model,
            temperature=self.temperature,
        )

    def system_message(self, message: str) -> dict:
        """Create a system message (with a plain-text-only instruction appended)."""
        return {
            "role": "system",
            "content": message + "\n DO NOT PRODUCE MARKDOWN, ONLY RESPOND IN PLAIN TEXT",
        }

    def user_message(self, message: str) -> dict:
        """Create a user message."""
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> dict:
        """Create an assistant message."""
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt, **kwargs) -> str:
        """
        Submit a prompt to the LLM and return the response content.

        Args:
            prompt: Non-empty list of message dicts, each with a "content" key
            **kwargs: Accepted for interface compatibility; not used here

        Returns:
            The ``content`` attribute of the client's response

        Raises:
            ValueError: If ``prompt`` is None or empty.
            Exception: Any error raised by the client is logged and re-raised.
        """
        if prompt is None:
            raise ValueError("Prompt is None")

        if len(prompt) == 0:
            raise ValueError("Prompt is empty")

        # Rough size estimate for the log only: ~4 characters per token.
        num_tokens = sum(len(message["content"]) for message in prompt) // 4
        logger.debug(f"Using model {self.model} for {num_tokens} tokens (approx)")

        logger.debug(f"Submitting prompt with {len(prompt)} messages")
        logger.debug(f"Prompt content preview: {str(prompt)[:500]}...")

        try:
            response = self.client.invoke(prompt)
            logger.debug(f"Response type: {type(response)}")
            logger.debug(f"Response content type: {type(response.content)}")
            logger.debug(
                f"Response content length: {len(response.content) if response.content else 0}"
            )
            logger.debug(
                f"Response content preview: {response.content[:200] if response.content else 'None'}..."
            )
            return response.content
        except Exception as e:
            logger.error(f"Error in submit_prompt: {e}")
            logger.error(f"Error type: {type(e)}")
            logger.error(f"Full traceback: {traceback.format_exc()}")
            raise


class NIMVanna(ChromaDB_VectorStore, NIMCustomLLM):
    """Vanna implementation using NVIDIA NIM for LLM and ChromaDB for vector storage."""

    def __init__(self, VectorConfig=None, LLMConfig=None):
        """Initialize the ChromaDB store with ``VectorConfig``, then the LLM with ``LLMConfig``."""
        ChromaDB_VectorStore.__init__(self, config=VectorConfig)
        NIMCustomLLM.__init__(self, config=LLMConfig)


class ElasticVectorStore(VannaBase):
    """
    Elasticsearch-based vector store for Vanna.

    This class provides vector storage and retrieval capabilities using Elasticsearch's
    dense_vector field type and kNN search functionality.

    Configuration:
        config: Dictionary with the following keys:
            - url: Elasticsearch connection URL (e.g., "http://localhost:9200")
            - index_name: Name of the Elasticsearch index to use (default: "vanna_vectors")
            - api_key: Optional API key for authentication
            - username: Optional username for basic auth
            - password: Optional password for basic auth
            - embedding_function: Function to generate embeddings (required)
    """

    def __init__(self, config=None):
        """
        Connect to Elasticsearch and make sure the index exists.

        Raises:
            ValueError: If ``config`` is empty or has no ``embedding_function``.
            ImportError: If the ``elasticsearch`` package is not installed.
        """
        VannaBase.__init__(self, config=config)

        if not config:
            raise ValueError("config must be passed for ElasticVectorStore")

        # Elasticsearch connection parameters
        self.url = config.get("url", "http://localhost:9200")
        self.index_name = config.get("index_name", "vanna_vectors")
        self.api_key = config.get("api_key")
        self.username = config.get("username")
        self.password = config.get("password")

        # Embedding function (required)
        if "embedding_function" not in config:
            raise ValueError("embedding_function must be provided in config")
        self.embedding_function = config["embedding_function"]

        self._init_elasticsearch_client()
        self._create_index_if_not_exists()

        logger.info(f"ElasticVectorStore initialized with index: {self.index_name}")

    def _init_elasticsearch_client(self):
        """Initialize the Elasticsearch client with authentication."""
        try:
            from elasticsearch import Elasticsearch
        except ImportError as e:
            # Chain the cause so the underlying import failure stays visible.
            raise ImportError(
                "elasticsearch package is required for ElasticVectorStore. "
                "Install it with: pip install elasticsearch"
            ) from e

        # An API key takes precedence over username/password.
        client_kwargs = {}
        if self.api_key:
            client_kwargs["api_key"] = self.api_key
        elif self.username and self.password:
            client_kwargs["basic_auth"] = (self.username, self.password)

        self.es_client = Elasticsearch(self.url, **client_kwargs)

        # Test connection (try but don't fail if ping doesn't work)
        try:
            if self.es_client.ping():
                logger.info(f"Successfully connected to Elasticsearch at {self.url}")
            else:
                logger.warning(f"Elasticsearch ping failed, but will try to proceed at {self.url}")
        except Exception as e:
            logger.warning(f"Elasticsearch ping check failed ({e}), but will try to proceed")

    def _create_index_if_not_exists(self):
        """Create the Elasticsearch index with appropriate mappings if it doesn't exist."""
        if self.es_client.indices.exists(index=self.index_name):
            logger.debug(f"Index {self.index_name} already exists")
            return

        # The mapping needs the vector size, so embed a probe string to measure it.
        embedding_dim = len(self._generate_embedding("test"))

        index_mapping = {
            "mappings": {
                "properties": {
                    "id": {"type": "keyword"},
                    "text": {"type": "text"},
                    "embedding": {
                        "type": "dense_vector",
                        "dims": embedding_dim,
                        "index": True,
                        "similarity": "cosine",
                    },
                    "metadata": {"type": "object", "enabled": True},
                    "type": {"type": "keyword"},  # ddl, documentation, sql
                    "created_at": {"type": "date"},
                }
            }
        }

        self.es_client.indices.create(index=self.index_name, body=index_mapping)
        logger.info(f"Created Elasticsearch index: {self.index_name}")

    def _generate_embedding(self, text: str) -> list[float]:
        """
        Generate an embedding for ``text`` using the configured embedding function.

        An ``embed_query`` method is preferred when present; otherwise the
        function is called directly. A nested ``[[...]]`` result (as returned by
        ``NVIDIAEmbeddingFunction.embed_query``) is unwrapped to its first element.

        Raises:
            ValueError: If the embedding function is neither callable nor has
                an ``embed_query`` method.
        """
        if hasattr(self.embedding_function, "embed_query"):
            result = self.embedding_function.embed_query(text)
        elif callable(self.embedding_function):
            result = self.embedding_function(text)
        else:
            raise ValueError("embedding_function must be callable or have embed_query method")

        if isinstance(result, list) and result and isinstance(result[0], list):
            return result[0]  # Extract the inner list
        return result  # type: ignore[return-value]

    def _index_document(self, doc_type: str, text: str, embedding_text: str, metadata: dict) -> str:
        """
        Embed and index one training document.

        The document ID is the MD5 hex digest of ``text``, so indexing the same
        text again reuses the same ID.

        Args:
            doc_type: Value stored in the "type" field ("ddl", "documentation" or "sql")
            text: Text stored in the "text" field and hashed for the ID
            embedding_text: Text that is actually embedded
            metadata: Dictionary stored in the "metadata" field

        Returns:
            Document ID
        """
        # MD5 is only a content fingerprint here, not a security measure.
        doc_id = hashlib.md5(text.encode(), usedforsecurity=False).hexdigest()

        doc = {
            "id": doc_id,
            "text": text,
            "embedding": self._generate_embedding(embedding_text),
            "type": doc_type,
            "metadata": metadata,
            # Timezone-aware UTC timestamp (datetime.utcnow() is deprecated).
            "created_at": datetime.now(timezone.utc).isoformat(),
        }

        self.es_client.index(index=self.index_name, id=doc_id, document=doc)
        return doc_id

    def _knn_search(self, question: str, doc_type: str, source_fields: list, **kwargs) -> list:
        """
        Run a kNN search restricted to documents of one type.

        Args:
            question: Text to embed as the query vector
            doc_type: Value the "type" field must match
            source_fields: Fields to return in each hit's "_source"
            **kwargs: ``top_k`` sets the number of neighbours (default 10)

        Returns:
            The raw hit list from the search response
        """
        top_k = kwargs.get("top_k", 10)

        search_query = {
            "knn": {
                "field": "embedding",
                "query_vector": self._generate_embedding(question),
                "k": top_k,
                "num_candidates": top_k * 2,
                "filter": {"term": {"type": doc_type}},
            },
            "_source": source_fields,
        }

        response = self.es_client.search(index=self.index_name, body=search_query)
        return response["hits"]["hits"]

    def add_ddl(self, ddl: str, **kwargs) -> str:
        """
        Add a DDL statement to the vector store.

        Args:
            ddl: The DDL statement to store
            **kwargs: Additional metadata

        Returns:
            Document ID
        """
        doc_id = self._index_document("ddl", ddl, ddl, kwargs)
        logger.debug(f"Added DDL to Elasticsearch: {doc_id}")
        return doc_id

    def add_documentation(self, documentation: str, **kwargs) -> str:
        """
        Add documentation to the vector store.

        Args:
            documentation: The documentation text to store
            **kwargs: Additional metadata

        Returns:
            Document ID
        """
        doc_id = self._index_document("documentation", documentation, documentation, kwargs)
        logger.debug(f"Added documentation to Elasticsearch: {doc_id}")
        return doc_id

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """
        Add a question-SQL pair to the vector store.

        The stored text (and the ID) combine question and SQL, but only the
        question is embedded.

        Args:
            question: The natural language question
            sql: The corresponding SQL query
            **kwargs: Additional metadata

        Returns:
            Document ID
        """
        combined_text = f"Question: {question}\nSQL: {sql}"
        metadata = {"question": question, "sql": sql, **kwargs}

        doc_id = self._index_document("sql", combined_text, question, metadata)
        logger.debug(f"Added question-SQL pair to Elasticsearch: {doc_id}")
        return doc_id

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        """
        Retrieve similar question-SQL pairs using vector similarity search.

        Args:
            question: The question to find similar examples for
            **kwargs: Additional parameters (e.g., top_k)

        Returns:
            List of dicts with "question", "sql" and "score" keys
        """
        hits = self._knn_search(question, "sql", ["text", "metadata", "type"], **kwargs)

        results = [
            {
                "question": hit["_source"]["metadata"].get("question", ""),
                "sql": hit["_source"]["metadata"].get("sql", ""),
                "score": hit["_score"],
            }
            for hit in hits
        ]

        logger.debug(f"Found {len(results)} similar question-SQL pairs")
        return results

    def get_related_ddl(self, question: str, **kwargs) -> list:
        """
        Retrieve related DDL statements using vector similarity search.

        Args:
            question: The question to find related DDL for
            **kwargs: Additional parameters (e.g., top_k)

        Returns:
            List of related DDL statements
        """
        hits = self._knn_search(question, "ddl", ["text"], **kwargs)

        results = [hit["_source"]["text"] for hit in hits]
        logger.debug(f"Found {len(results)} related DDL statements")
        return results

    def get_related_documentation(self, question: str, **kwargs) -> list:
        """
        Retrieve related documentation using vector similarity search.

        Args:
            question: The question to find related documentation for
            **kwargs: Additional parameters (e.g., top_k)

        Returns:
            List of related documentation
        """
        hits = self._knn_search(question, "documentation", ["text"], **kwargs)

        results = [hit["_source"]["text"] for hit in hits]
        logger.debug(f"Found {len(results)} related documentation entries")
        return results

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """
        Remove a training data entry by ID.

        Args:
            id: The document ID to remove
            **kwargs: Additional parameters

        Returns:
            True if the delete call succeeded, False if it raised
        """
        try:
            self.es_client.delete(index=self.index_name, id=id)
            logger.debug(f"Removed training data: {id}")
            return True
        except Exception as e:
            logger.error(f"Error removing training data {id}: {e}")
            return False

    def generate_embedding(self, data: str, **kwargs) -> list[float]:
        """
        Generate embedding for given data (required by Vanna base class).

        Args:
            data: Text to generate embedding for
            **kwargs: Additional parameters

        Returns:
            Embedding vector
        """
        return self._generate_embedding(data)

    def get_training_data(self, **kwargs) -> list:
        """
        Get all training data from the vector store (required by Vanna base class).

        At most 10000 documents are requested in a single search. Any error is
        logged and an empty list is returned.

        NOTE(review): this returns a list of dicts; confirm that callers of the
        Vanna base-class method do not expect a DataFrame.

        Args:
            **kwargs: Additional parameters

        Returns:
            List of dicts with "id", "type", "text" and "metadata" keys
        """
        try:
            query = {
                "query": {"match_all": {}},
                "size": 10000,  # Adjust based on expected data size
            }

            response = self.es_client.search(index=self.index_name, body=query)

            return [
                {
                    "id": hit["_id"],
                    "type": hit["_source"].get("type"),
                    "text": hit["_source"].get("text"),
                    "metadata": hit["_source"].get("metadata", {}),
                }
                for hit in response["hits"]["hits"]
            ]
        except Exception as e:
            logger.error(f"Error getting training data: {e}")
            return []


class ElasticNIMVanna(ElasticVectorStore, NIMCustomLLM):
    """
    Vanna implementation using NVIDIA NIM for LLM and Elasticsearch for vector storage.

    This class combines ElasticVectorStore for vector operations with NIMCustomLLM
    for SQL generation, providing an alternative to ChromaDB-based storage.

    Example:
        >>> vanna = ElasticNIMVanna(
        ...     VectorConfig={
        ...         "url": "http://localhost:9200",
        ...         "index_name": "my_sql_vectors",
        ...         "username": "elastic",
        ...         "password": "changeme",
        ...         "embedding_function": NVIDIAEmbeddingFunction(
        ...             api_key="your-api-key",
        ...             model="nvidia/llama-3.2-nv-embedqa-1b-v2"
        ...         )
        ...     },
        ...     LLMConfig={
        ...         "api_key": "your-api-key",
        ...         "model": "meta/llama-3.1-70b-instruct"
        ...     }
        ... )
    """

    def __init__(self, VectorConfig=None, LLMConfig=None):
        """Initialize the Elasticsearch store with ``VectorConfig``, then the LLM with ``LLMConfig``."""
        ElasticVectorStore.__init__(self, config=VectorConfig)
        NIMCustomLLM.__init__(self, config=LLMConfig)


class NVIDIAEmbeddingFunction:
    """
    A class that can be used as a replacement for chroma's DefaultEmbeddingFunction.
    It takes in input (text or list of texts) and returns embeddings using NVIDIA's API.

    This class fixes two major interface compatibility issues between ChromaDB and NVIDIA embeddings:

    1. INPUT FORMAT MISMATCH:
       - ChromaDB passes ['query text'] (list) to embed_query()
       - But langchain_nvidia's embed_query() expects 'query text' (string)
       - When list is passed, langchain does [text] internally → [['query text']] → API 500 error
       - FIX: Detect list input and extract string before calling langchain

    2. OUTPUT FORMAT MISMATCH:
       - ChromaDB expects embed_query() to return [[embedding_vector]] (list of embeddings)
       - But langchain returns [embedding_vector] (single embedding vector)
       - This causes: TypeError: 'float' object cannot be converted to 'Sequence'
       - FIX: Wrap single embedding in list: return [embeddings]
    """

    def __init__(self, api_key, model="nvidia/llama-3.2-nv-embedqa-1b-v2"):
        """
        Initialize the embedding function with the API key and model name.

        Parameters:
        - api_key (str): The API key for authentication.
        - model (str): The model name to use for embeddings.
          Default: nvidia/llama-3.2-nv-embedqa-1b-v2 (tested and working)
        """
        self.api_key = api_key
        self.model = model

        logger.info(f"Initializing NVIDIA embeddings with model: {model}")
        logger.debug(f"API key length: {len(api_key) if api_key else 0}")

        self.embeddings = NVIDIAEmbeddings(
            api_key=api_key, model_name=model, input_type="query", truncate="NONE"
        )
        logger.info("Successfully initialized NVIDIA embeddings")

    def __call__(self, input):
        """
        Call method to make the object callable, as required by chroma's EmbeddingFunction interface.

        NOTE: This method is used by ChromaDB for batch embedding operations.
        The single-query case, with its input/output format fixes, is handled by
        the separate embed_query() method of this class.

        Parameters:
        - input (str or list): The input data for which embeddings need to be generated.

        Returns:
        - embedding (list): One embedding vector per input text (always a list of vectors).
        """
        logger.debug(f"__call__ method called with input type: {type(input)}")
        logger.debug(f"__call__ input: {input}")

        # Ensure input is a list, as required by ChromaDB
        input_data = [input] if isinstance(input, str) else input

        logger.debug(f"Processing {len(input_data)} texts for embedding")

        # Each text is embedded with its own embed_query call.
        embeddings = []
        for i, text in enumerate(input_data):
            logger.debug(f"Embedding text {i+1}/{len(input_data)}: {text[:50]}...")
            embeddings.append(self.embeddings.embed_query(text))

        logger.debug(f"Generated {len(embeddings)} embeddings")
        # Always return a list of embeddings for ChromaDB
        return embeddings

    def name(self):
        """
        Returns a custom name for the embedding function.

        Returns:
            str: The name of the embedding function.
        """
        return "NVIDIA Embedding Function"

    def embed_query(self, input: str | list[str]) -> list[list[float]]:
        """
        Generate embeddings for a single query.

        ChromaDB calls this method with ['query text'] (list) but langchain_nvidia expects 'query text' (string).
        We must extract the string from the list to prevent API 500 errors.

        ChromaDB expects this method to return [[embedding_vector]] (list of embeddings)
        but langchain returns [embedding_vector] (single embedding). We wrap it in a list.

        Raises:
            ValueError: If ``input`` is a list that does not contain exactly one element.
            Exception: Any error raised by the embeddings client is logged and re-raised.
        """
        logger.debug(f"Embedding query: {input}")
        logger.debug(f"Input type: {type(input)}")
        logger.debug(f"Using model: {self.model}")

        # Handle ChromaDB's list input format: unwrap a one-element list.
        if isinstance(input, list):
            if len(input) != 1:
                logger.error(f"Unexpected list length: {len(input)}")
                raise ValueError(
                    f"Expected single string or list with one element, got list with {len(input)} elements"
                )
            query_text = input[0]
            logger.debug(f"Extracted string from list: {query_text}")
        else:
            query_text = input

        try:
            # Call langchain_nvidia with the extracted string
            embeddings = self.embeddings.embed_query(query_text)
            logger.debug(
                f"Successfully generated embeddings of length: {len(embeddings) if embeddings else 0}"
            )

            # Wrap single embedding in list for ChromaDB compatibility
            return [embeddings]
        except Exception as e:
            logger.error(f"Error generating embeddings for query: {e}")
            logger.error(f"Error type: {type(e)}")
            logger.error(f"Query text: {query_text}")
            logger.error(f"Full traceback: {traceback.format_exc()}")
            raise

    def embed_documents(self, input: list[str]) -> list[list[float]]:
        """
        Generate embeddings for multiple documents.

        The list is passed to the underlying client unchanged; nested lists are
        NOT flattened here, so callers must supply a flat list of strings.

        Raises:
            Exception: Any error raised by the embeddings client is logged and re-raised.
        """
        logger.debug(f"Embedding {len(input)} documents...")
        logger.debug(f"Using model: {self.model}")

        try:
            embeddings = self.embeddings.embed_documents(input)
            logger.debug("Successfully generated document embeddings")
            return embeddings
        except Exception as e:
            logger.error(f"Error generating document embeddings: {e}")
            logger.error(f"Error type: {type(e)}")
            logger.error(f"Input documents count: {len(input)}")
            logger.error(f"Full traceback: {traceback.format_exc()}")
            raise


def chunk_documentation(text: str, max_chars: int = 1500) -> list[str]:
    """
    Split long documentation into smaller chunks to avoid token limits.

    Paragraphs (separated by a blank line) are packed greedily into chunks of
    at most ``max_chars``. Any chunk that is still too long is split again on
    ". " sentence boundaries. A single sentence longer than ``max_chars`` is
    kept whole, so the limit is approximate.

    Args:
        text: The documentation text to chunk
        max_chars: Maximum characters per chunk (approximate)

    Returns:
        List of text chunks
    """
    if len(text) <= max_chars:
        return [text]

    # Pass 1: pack paragraphs.
    chunks = []
    current_chunk = ""
    for paragraph in text.split("\n\n"):
        # "+ 2" accounts for the "\n\n" separator that joining would add.
        if current_chunk and len(current_chunk) + len(paragraph) + 2 > max_chars:
            chunks.append(current_chunk.strip())
            current_chunk = paragraph
        elif current_chunk:
            current_chunk += "\n\n" + paragraph
        else:
            current_chunk = paragraph

    if current_chunk.strip():
        chunks.append(current_chunk.strip())

    # Pass 2: split oversized chunks into sentence groups.
    final_chunks = []
    for chunk in chunks:
        if len(chunk) <= max_chars:
            final_chunks.append(chunk)
            continue

        temp_chunk = ""
        for sentence in chunk.split(". "):
            if temp_chunk and len(temp_chunk) + len(sentence) + 2 > max_chars:
                # split(". ") removed the period that ended this group; restore it.
                final_chunks.append(temp_chunk.strip() + ".")
                temp_chunk = sentence
            elif temp_chunk:
                temp_chunk += ". " + sentence
            else:
                temp_chunk = sentence
        if temp_chunk.strip():
            final_chunks.append(temp_chunk.strip())

    return final_chunks


def _train_ddl_from_database(vn) -> None:
    """Train ``vn`` on DDL read from the connected database; errors are logged, not raised."""
    try:
        # Try SQLite-specific query first
        try:
            df_ddl = vn.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null")
        except Exception:
            # For non-SQLite databases, try standard information_schema
            try:
                df_ddl = vn.run_sql(
                    """
                    SELECT table_name, column_name, data_type
                    FROM information_schema.columns
                    WHERE table_schema = 'public'
                    """
                )
            except Exception as e:
                logger.warning(f"Could not auto-load DDL from database: {e}")
                df_ddl = None

        if df_ddl is None or df_ddl.empty:
            return

        ddl_count = len(df_ddl)
        logger.info(f"Found {ddl_count} DDL statements in database")

        if "sql" not in df_ddl.columns:
            # The information_schema fallback yields column rows, not DDL text,
            # so there is nothing that can be passed to vn.train(ddl=...).
            logger.warning("Database schema query returned no 'sql' column; no DDL trained from database")
            return

        ddl_trained = 0
        for i, ddl in enumerate(df_ddl["sql"].to_list(), 1):
            if ddl:
                logger.debug(f"Training DDL {i}/{ddl_count}: {ddl[:100]}...")
                vn.train(ddl=ddl)
                ddl_trained += 1

        # Report what was actually trained, not the number of rows fetched.
        logger.info(f"Successfully trained {ddl_trained} DDL statements from database")
    except Exception as e:
        logger.error(f"Error loading DDL from database: {e}")
        # Continue with training data from YAML


def _train_from_yaml(vn, training_data_path: str) -> None:
    """Load the YAML training file and train ``vn`` on its ddl/documentation/sql sections."""
    import yaml

    try:
        with open(training_data_path, "r", encoding="utf-8") as f:
            # safe_load returns None for an empty file; treat that as "no sections".
            training_data = yaml.safe_load(f) or {}

        if not isinstance(training_data, dict):
            raise ValueError(
                f"Training data must be a mapping at the top level, got {type(training_data).__name__}"
            )

        logger.info("Successfully loaded YAML training data")
        logger.info(
            f"Training data keys: {list(training_data.keys()) if training_data else 'None'}"
        )

        # Train DDL statements ("or []" also covers a section that is present but null)
        ddl_statements = training_data.get("ddl") or []
        logger.info(f"Found {len(ddl_statements)} DDL statements")

        ddl_trained = 0
        for i, ddl_statement in enumerate(ddl_statements, 1):
            if isinstance(ddl_statement, str) and ddl_statement.strip():
                logger.debug(f"Training DDL {i}: {ddl_statement[:100]}...")
                vn.train(ddl=ddl_statement)
                ddl_trained += 1
            else:
                logger.warning(f"Skipping empty DDL statement at index {i}")

        logger.info(
            f"Successfully trained {ddl_trained}/{len(ddl_statements)} DDL statements"
        )

        # Train documentation with chunking
        documentation_list = training_data.get("documentation") or []
        logger.info(f"Found {len(documentation_list)} documentation entries")

        doc_chunks = []
        for i, doc_entry in enumerate(documentation_list, 1):
            if isinstance(doc_entry, str) and doc_entry.strip():
                logger.debug(
                    f"Processing documentation entry {i}: {doc_entry[:100]}..."
                )
                # Chunk each documentation entry to avoid token limits
                doc_chunks.extend(chunk_documentation(doc_entry))
            else:
                logger.warning(f"Skipping empty documentation entry at index {i}")

        logger.info(f"Split documentation into {len(doc_chunks)} total chunks")

        docs_trained = 0
        for i, chunk in enumerate(doc_chunks, 1):
            try:
                logger.debug(
                    f"Training documentation chunk {i}/{len(doc_chunks)} ({len(chunk)} chars)"
                )
                vn.train(documentation=chunk)
                docs_trained += 1
            except Exception as e:
                logger.error(f"Error training documentation chunk {i}: {e}")
                # Continue with other chunks

        # Failed chunks are excluded from the reported count.
        logger.info(f"Successfully trained {docs_trained} documentation chunks")

        # Train question-SQL pairs
        question_sql_pairs = training_data.get("sql") or []
        logger.info(f"Found {len(question_sql_pairs)} question-SQL pairs")

        pairs_trained = 0
        for i, pair in enumerate(question_sql_pairs, 1):
            if not isinstance(pair, dict):
                logger.warning(f"Skipping question-SQL pair {i}: not a mapping")
                continue

            # Null or non-string values are normalised so .strip() cannot fail.
            question = str(pair.get("question") or "")
            sql = str(pair.get("sql") or "")
            if question.strip() and sql.strip():
                logger.debug(
                    f"Training question-SQL pair {i}: Q='{question[:50]}...' SQL='{sql[:50]}...'"
                )
                vn.train(question=question, sql=sql)
                pairs_trained += 1
            else:
                if not question.strip():
                    logger.warning(f"Skipping question-SQL pair {i}: empty question")
                if not sql.strip():
                    logger.warning(f"Skipping question-SQL pair {i}: empty SQL")

        logger.info(
            f"Successfully trained {pairs_trained}/{len(question_sql_pairs)} question-SQL pairs"
        )

        # Summary
        total_trained = ddl_trained + docs_trained + pairs_trained
        logger.info("=== Training Summary ===")
        logger.info(f"  DDL statements: {ddl_trained}")
        logger.info(f"  Documentation chunks: {docs_trained}")
        logger.info(f"  Question-SQL pairs: {pairs_trained}")
        logger.info(f"  Total items trained: {total_trained}")

    except yaml.YAMLError as e:
        logger.error(f"Error parsing YAML file {training_data_path}: {e}")
        raise
    except Exception as e:
        logger.error(f"Error loading training data from {training_data_path}: {e}")
        raise


def init_vanna(vn, training_data_path: str | None = None):
    """
    Initialize and train a Vanna instance for SQL generation using configurable training data.

    Training happens in two steps:
      1. If ``vn`` has a ``run_sql`` function set, DDL is read from the database
         (``sqlite_master`` first, then ``information_schema``) and trained.
         Failures in this step are logged and do not stop initialization.
      2. If ``training_data_path`` points to an existing YAML file, its ``ddl``,
         ``documentation`` and ``sql`` sections are trained. A missing file only
         produces a warning; a file that cannot be loaded or trained raises.

    Args:
        vn: Vanna instance to be trained and configured
        training_data_path: Path to YAML file containing training data. If None, no training is applied.

    Returns:
        None: Modifies the Vanna instance in-place

    Raises:
        Exception: Errors while parsing the YAML file or training its DDL or
            question-SQL entries are logged and re-raised.

    Example:
        >>> vn = NIMVanna(VectorConfig=vector_config, LLMConfig=llm_config)
        >>> vn.connect_to_sqlite("path/to/database.db")
        >>> init_vanna(vn, "path/to/training_data.yaml")
        >>> # Vanna is now ready to generate SQL queries
    """
    logger.info("=== Starting Vanna initialization ===")

    # Get and train DDL from the database (if a run_sql function has been set)
    if hasattr(vn, "run_sql") and vn.run_sql_is_set:
        logger.info("Loading DDL from database...")
        _train_ddl_from_database(vn)

    # Load and apply training data from YAML file
    if not training_data_path:
        logger.info("No training data path provided, skipping YAML training")
    elif not os.path.exists(training_data_path):
        logger.info(f"Training data path provided: {training_data_path}")
        logger.warning(f"Training data file does not exist: {training_data_path}")
        logger.warning("Proceeding without YAML training data")
    else:
        logger.info(f"Training data path provided: {training_data_path}")
        logger.info("Training data file exists, loading YAML...")
        _train_from_yaml(vn, training_data_path)

    logger.info("=== Vanna initialization completed ===")