Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ dependencies = [
# Keep sorted!!!
"aioboto3>=11.0.0",
"authlib~=1.5",
"chromadb~=1.1.0",
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

🧩 Analysis chain

Verify chromadb version and precision.

The constraint chromadb~=1.1.0 uses 3-digit precision, while project guidelines recommend 2 digits (e.g., ~=1.1). Verify whether the patch-level pin is necessary or if chromadb~=1.1 would suffice.

Run the following to check the latest chromadb version and any known security advisories:


🏁 Script executed:

#!/bin/bash
# Check PyPI for latest chromadb version
echo "=== Latest chromadb version ==="
curl -s https://pypi.org/pypi/chromadb/json | jq -r '.info.version'

# Check for security advisories
echo -e "\n=== Security advisories ==="
gh api graphql -f query='
{
  securityVulnerabilities(first: 5, ecosystem: PIP, package: "chromadb") {
    nodes {
      advisory {
        summary
        severity
        publishedAt
      }
      vulnerableVersionRange
      firstPatchedVersion {
        identifier
      }
    }
  }
}'

Length of output: 116


Use 2-digit precision for chromadb pin
Change chromadb~=1.1.0 to chromadb~=1.1 in pyproject.toml to comply with the guideline of using two-digit precision for ~= and allow future patch-level updates.

🤖 Prompt for AI Agents
In pyproject.toml around line 23, the dependency is pinned as "chromadb~=1.1.0"
which uses three-digit precision; change it to "chromadb~=1.1" so the ~=
operator uses two-digit precision and permits future patch-level updates (update
the single dependency entry at line 23 accordingly).

"click~=8.1",
"colorama~=0.4.6",
"datasets~=4.0", # workaround for uv's solver choosing different versions of datasets based on sys_platform
Expand All @@ -37,6 +38,7 @@ dependencies = [
"openinference-semantic-conventions~=0.1.14",
"openpyxl~=3.1",
"optuna~=4.4.0",
"pandas~=2.0",
"pip>=24.3.1",
"pkce==1.0.3",
"pkginfo~=1.12",
Expand All @@ -48,6 +50,7 @@ dependencies = [
"rich~=13.9",
"tabulate~=0.9",
"uvicorn[standard]~=0.34",
"vanna~=0.7.9",
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

🧩 Analysis chain

Verify vanna version and consider 2-digit precision.

The constraint vanna~=0.7.9 uses 3-digit precision. Per project guidelines, prefer 2-digit precision (e.g., vanna~=0.7) unless a specific patch is required. Additionally, verify this is the latest stable version.


🏁 Script executed:

#!/bin/bash
# Check PyPI for latest vanna version
echo "=== Latest vanna version ==="
curl -s https://pypi.org/pypi/vanna/json | jq -r '.info.version'

# Check for security advisories
echo -e "\n=== Security advisories ==="
gh api graphql -f query='
{
  securityVulnerabilities(first: 5, ecosystem: PIP, package: "vanna") {
    nodes {
      advisory {
        summary
        severity
        publishedAt
      }
      vulnerableVersionRange
      firstPatchedVersion {
        identifier
      }
    }
  }
}'

Length of output: 696


Use 2-digit precision for vanna constraint
0.7.9 is the latest stable release and contains no unpatched vulnerabilities, so per project guidelines switch from vanna~=0.7.9 to:

vanna~=0.7
🤖 Prompt for AI Agents
In pyproject.toml around line 53, the dependency pin uses three-segment
precision "vanna~=0.7.9"; update it to two-segment precision per project
guidelines by changing the constraint to "vanna~=0.7" so it allows any
compatible 0.7.x release while maintaining the intended compatibility range.

"wikipedia~=1.4",
]
requires-python = ">=3.11,<3.14"
Expand Down Expand Up @@ -80,6 +83,7 @@ mcp = ["nvidia-nat-mcp"]
mem0ai = ["nvidia-nat-mem0ai"]
opentelemetry = ["nvidia-nat-opentelemetry"]
phoenix = ["nvidia-nat-phoenix"]
postgres = ["psycopg2-binary~=2.9"]
profiling = ["nvidia-nat-profiling"] # meta-package
ragaai = ["nvidia-nat-ragaai"]
mysql = ["nvidia-nat-mysql"]
Expand Down
1 change: 1 addition & 0 deletions src/nat/retriever/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@
# Import any providers which need to be automatically registered here
import nat.retriever.milvus.register
import nat.retriever.nemo_retriever.register
import nat.retriever.sql_retriever.register
21 changes: 21 additions & 0 deletions src/nat/retriever/sql_retriever/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""SQL Retriever module for NeMo Agent Toolkit."""

from nat.retriever.sql_retriever.sql_retriever import SQLRetriever

__all__ = ["SQLRetriever"]

123 changes: 123 additions & 0 deletions src/nat/retriever/sql_retriever/register.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Registration module for SQL Retriever."""

import logging

from pydantic import BaseModel
from pydantic import Field

from nat.builder.builder import Builder
from nat.cli.register_workflow import register_retriever
from nat.data_models.retriever import RetrieverBaseConfig
from nat.retriever.sql_retriever.sql_retriever import SQLRetriever

logger = logging.getLogger(__name__)


class SQLRetrieverConfig(RetrieverBaseConfig, name="sql_retriever"):
"""
Configuration for SQL Retriever.
This retriever uses Vanna AI with NVIDIA NIM to convert natural language queries
into SQL and retrieve data from SQL databases.
Supported database types:
- sqlite: SQLite databases (local file-based)
- postgres/postgresql: PostgreSQL databases
- sql: Generic SQL databases via SQLAlchemy (MySQL, SQL Server, Oracle, etc.)
"""

llm_name: str = Field(description="Name of the LLM to use for SQL generation")
embedding_name: str = Field(
description="Name of the embedding model to use for Vanna training data"
)
vector_store_path: str = Field(
description="Path to ChromaDB vector store for Vanna training data"
)
db_connection_string: str = Field(
description=(
"Database connection string. Format depends on db_type:\n"
"- sqlite: Path to .db file (e.g., '/path/to/database.db')\n"
"- postgres: Connection string (e.g., 'postgresql://user:pass@host:port/db')\n"
"- sql: SQLAlchemy connection string (e.g., 'mysql+pymysql://user:pass@host/db')"
)
)
db_type: str = Field(
default="sqlite",
description="Type of database: 'sqlite', 'postgres', or 'sql' (generic SQL via SQLAlchemy)",
)
training_data_path: str | None = Field(
default=None,
description="Path to YAML file containing Vanna training data (DDL, documentation, question-SQL pairs)",
)
max_results: int = Field(
default=100,
description="Maximum number of results to return from SQL queries",
)
nvidia_api_key: str | None = Field(
default=None,
description="NVIDIA API key (optional, defaults to NVIDIA_API_KEY environment variable)",
)


@register_retriever(config_type=SQLRetrieverConfig)
async def create_sql_retriever(config: SQLRetrieverConfig, builder: Builder):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Add return type hint.

The function signature is missing a return type hint. Per coding guidelines, all public APIs must have type hints on parameters and return values.

Apply this diff:

 @register_retriever(config_type=SQLRetrieverConfig)
-async def create_sql_retriever(config: SQLRetrieverConfig, builder: Builder):
+async def create_sql_retriever(config: SQLRetrieverConfig, builder: Builder) -> SQLRetriever:

As per coding guidelines

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
async def create_sql_retriever(config: SQLRetrieverConfig, builder: Builder):
@register_retriever(config_type=SQLRetrieverConfig)
async def create_sql_retriever(config: SQLRetrieverConfig, builder: Builder) -> SQLRetriever:
🤖 Prompt for AI Agents
In src/nat/retriever/sql_retriever/register.py around line 78, the async
function create_sql_retriever(config: SQLRetrieverConfig, builder: Builder) is
missing a return type annotation; update the signature to include the concrete
return type (e.g., -> SQLRetriever), import that type if necessary, and ensure
the annotation matches the actual object returned by the function.

"""
Create and register a SQL Retriever instance.
Args:
config: SQLRetrieverConfig containing all necessary parameters
builder: Builder instance for accessing LLM and embedder configurations
Returns:
SQLRetriever: Configured SQL retriever instance
Example YAML configuration:
```yaml
retrievers:
- name: sql_retriever
type: sql_retriever
llm_name: nim_llm
embedding_name: nim_embeddings
vector_store_path: ./vanna_vector_store
db_connection_string: ./database.db
db_type: sqlite
training_data_path: ./training_data.yaml
max_results: 100
```
"""
logger.info(f"Creating SQL Retriever with config: {config.name}")

# Get LLM and embedder configurations from builder
llm_config = builder.get_llm_config(config.llm_name)
embedder_config = builder.get_embedder_config(config.embedding_name)

# Create SQL retriever instance
retriever = SQLRetriever(
llm_config=llm_config,
embedder_config=embedder_config,
vector_store_path=config.vector_store_path,
db_connection_string=config.db_connection_string,
db_type=config.db_type,
training_data_path=config.training_data_path,
nvidia_api_key=config.nvidia_api_key,
max_results=config.max_results,
)

logger.info(f"SQL Retriever '{config.name}' created successfully")
return retriever

Loading