-
Notifications
You must be signed in to change notification settings - Fork 382
[Feature] SQL Retriever tool using Vanna #894
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
8bdbd16
fd6ead9
4d8a487
d44bd11
5ea8a5a
01a1023
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ dependencies = [ | |
# Keep sorted!!! | ||
"aioboto3>=11.0.0", | ||
"authlib~=1.5", | ||
"chromadb~=1.1.0", | ||
"click~=8.1", | ||
"colorama~=0.4.6", | ||
"datasets~=4.0", # workaround for uv's solver choosing different versions of datasets based on sys_platform | ||
|
@@ -37,6 +38,7 @@ dependencies = [ | |
"openinference-semantic-conventions~=0.1.14", | ||
"openpyxl~=3.1", | ||
"optuna~=4.4.0", | ||
"pandas~=2.0", | ||
"pip>=24.3.1", | ||
"pkce==1.0.3", | ||
"pkginfo~=1.12", | ||
|
@@ -48,6 +50,7 @@ dependencies = [ | |
"rich~=13.9", | ||
"tabulate~=0.9", | ||
"uvicorn[standard]~=0.34", | ||
"vanna~=0.7.9", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion | 🟠 Major 🧩 Analysis chainVerify vanna version and consider 2-digit precision. The constraint 🏁 Script executed: #!/bin/bash
# Check PyPI for latest vanna version
echo "=== Latest vanna version ==="
curl -s https://pypi.org/pypi/vanna/json | jq -r '.info.version'
# Check for security advisories
echo -e "\n=== Security advisories ==="
gh api graphql -f query='
{
securityVulnerabilities(first: 5, ecosystem: PIP, package: "vanna") {
nodes {
advisory {
summary
severity
publishedAt
}
vulnerableVersionRange
firstPatchedVersion {
identifier
}
}
}
}' Length of output: 696 Use 2-digit precision for vanna constraint vanna~=0.7 🤖 Prompt for AI Agents
|
||
"wikipedia~=1.4", | ||
] | ||
requires-python = ">=3.11,<3.14" | ||
|
@@ -80,6 +83,7 @@ mcp = ["nvidia-nat-mcp"] | |
mem0ai = ["nvidia-nat-mem0ai"] | ||
opentelemetry = ["nvidia-nat-opentelemetry"] | ||
phoenix = ["nvidia-nat-phoenix"] | ||
postgres = ["psycopg2-binary~=2.9"] | ||
profiling = ["nvidia-nat-profiling"] # meta-package | ||
ragaai = ["nvidia-nat-ragaai"] | ||
mysql = ["nvidia-nat-mysql"] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""SQL Retriever module for NeMo Agent Toolkit.""" | ||
|
||
from nat.retriever.sql_retriever.sql_retriever import SQLRetriever | ||
|
||
__all__ = ["SQLRetriever"] | ||
|
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,123 @@ | ||||||||
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||||||||
# SPDX-License-Identifier: Apache-2.0 | ||||||||
# | ||||||||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||
# you may not use this file except in compliance with the License. | ||||||||
# You may obtain a copy of the License at | ||||||||
# | ||||||||
# http://www.apache.org/licenses/LICENSE-2.0 | ||||||||
# | ||||||||
# Unless required by applicable law or agreed to in writing, software | ||||||||
# distributed under the License is distributed on an "AS IS" BASIS, | ||||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||
# See the License for the specific language governing permissions and | ||||||||
# limitations under the License. | ||||||||
|
||||||||
"""Registration module for SQL Retriever.""" | ||||||||
|
||||||||
import logging | ||||||||
|
||||||||
from pydantic import BaseModel | ||||||||
from pydantic import Field | ||||||||
|
||||||||
from nat.builder.builder import Builder | ||||||||
from nat.cli.register_workflow import register_retriever | ||||||||
from nat.data_models.retriever import RetrieverBaseConfig | ||||||||
from nat.retriever.sql_retriever.sql_retriever import SQLRetriever | ||||||||
|
||||||||
logger = logging.getLogger(__name__) | ||||||||
|
||||||||
|
||||||||
class SQLRetrieverConfig(RetrieverBaseConfig, name="sql_retriever"): | ||||||||
""" | ||||||||
Configuration for SQL Retriever. | ||||||||
This retriever uses Vanna AI with NVIDIA NIM to convert natural language queries | ||||||||
into SQL and retrieve data from SQL databases. | ||||||||
Supported database types: | ||||||||
- sqlite: SQLite databases (local file-based) | ||||||||
- postgres/postgresql: PostgreSQL databases | ||||||||
- sql: Generic SQL databases via SQLAlchemy (MySQL, SQL Server, Oracle, etc.) | ||||||||
""" | ||||||||
|
||||||||
llm_name: str = Field(description="Name of the LLM to use for SQL generation") | ||||||||
embedding_name: str = Field( | ||||||||
description="Name of the embedding model to use for Vanna training data" | ||||||||
) | ||||||||
vector_store_path: str = Field( | ||||||||
description="Path to ChromaDB vector store for Vanna training data" | ||||||||
) | ||||||||
db_connection_string: str = Field( | ||||||||
description=( | ||||||||
"Database connection string. Format depends on db_type:\n" | ||||||||
"- sqlite: Path to .db file (e.g., '/path/to/database.db')\n" | ||||||||
"- postgres: Connection string (e.g., 'postgresql://user:pass@host:port/db')\n" | ||||||||
"- sql: SQLAlchemy connection string (e.g., 'mysql+pymysql://user:pass@host/db')" | ||||||||
) | ||||||||
) | ||||||||
db_type: str = Field( | ||||||||
default="sqlite", | ||||||||
description="Type of database: 'sqlite', 'postgres', or 'sql' (generic SQL via SQLAlchemy)", | ||||||||
) | ||||||||
training_data_path: str | None = Field( | ||||||||
default=None, | ||||||||
description="Path to YAML file containing Vanna training data (DDL, documentation, question-SQL pairs)", | ||||||||
) | ||||||||
max_results: int = Field( | ||||||||
default=100, | ||||||||
description="Maximum number of results to return from SQL queries", | ||||||||
) | ||||||||
nvidia_api_key: str | None = Field( | ||||||||
default=None, | ||||||||
description="NVIDIA API key (optional, defaults to NVIDIA_API_KEY environment variable)", | ||||||||
) | ||||||||
|
||||||||
|
||||||||
@register_retriever(config_type=SQLRetrieverConfig) | ||||||||
async def create_sql_retriever(config: SQLRetrieverConfig, builder: Builder): | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add return type hint. The function signature is missing a return type hint. Per coding guidelines, all public APIs must have type hints on parameters and return values. Apply this diff: @register_retriever(config_type=SQLRetrieverConfig)
-async def create_sql_retriever(config: SQLRetrieverConfig, builder: Builder):
+async def create_sql_retriever(config: SQLRetrieverConfig, builder: Builder) -> SQLRetriever: As per coding guidelines 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
||||||||
""" | ||||||||
Create and register a SQL Retriever instance. | ||||||||
Args: | ||||||||
config: SQLRetrieverConfig containing all necessary parameters | ||||||||
builder: Builder instance for accessing LLM and embedder configurations | ||||||||
Returns: | ||||||||
SQLRetriever: Configured SQL retriever instance | ||||||||
Example YAML configuration: | ||||||||
```yaml | ||||||||
retrievers: | ||||||||
- name: sql_retriever | ||||||||
type: sql_retriever | ||||||||
llm_name: nim_llm | ||||||||
embedding_name: nim_embeddings | ||||||||
vector_store_path: ./vanna_vector_store | ||||||||
db_connection_string: ./database.db | ||||||||
db_type: sqlite | ||||||||
training_data_path: ./training_data.yaml | ||||||||
max_results: 100 | ||||||||
``` | ||||||||
""" | ||||||||
logger.info(f"Creating SQL Retriever with config: {config.name}") | ||||||||
|
||||||||
# Get LLM and embedder configurations from builder | ||||||||
llm_config = builder.get_llm_config(config.llm_name) | ||||||||
embedder_config = builder.get_embedder_config(config.embedding_name) | ||||||||
|
||||||||
# Create SQL retriever instance | ||||||||
retriever = SQLRetriever( | ||||||||
llm_config=llm_config, | ||||||||
embedder_config=embedder_config, | ||||||||
vector_store_path=config.vector_store_path, | ||||||||
db_connection_string=config.db_connection_string, | ||||||||
db_type=config.db_type, | ||||||||
training_data_path=config.training_data_path, | ||||||||
nvidia_api_key=config.nvidia_api_key, | ||||||||
max_results=config.max_results, | ||||||||
) | ||||||||
|
||||||||
logger.info(f"SQL Retriever '{config.name}' created successfully") | ||||||||
return retriever | ||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion | 🟠 Major
🧩 Analysis chain
Verify chromadb version and precision.
The constraint
chromadb~=1.1.0
uses 3-digit precision, while project guidelines recommend 2 digits (e.g.,~=1.1
). Verify whether the patch-level pin is necessary or ifchromadb~=1.1
would suffice.Run the following to check the latest chromadb version and any known security advisories:
🏁 Script executed:
Length of output: 116
Use 2-digit precision for chromadb pin
Change
chromadb~=1.1.0
tochromadb~=1.1
in pyproject.toml to comply with the guideline of using two-digit precision for~=
and allow future patch-level updates.🤖 Prompt for AI Agents