diff --git a/cratedb_mcp/__main__.py b/cratedb_mcp/__main__.py index 6803bfc..d14f1e6 100644 --- a/cratedb_mcp/__main__.py +++ b/cratedb_mcp/__main__.py @@ -1,17 +1,10 @@ -import hishel import httpx from mcp.server.fastmcp import FastMCP -from .knowledge import DocumentationIndex, Queries, documentation_url_permitted -from .settings import DOCS_CACHE_TTL, HTTP_TIMEOUT, HTTP_URL +from .knowledge import DocumentationIndex, Queries +from .settings import HTTP_TIMEOUT, HTTP_URL from .util.sql import sql_is_permitted -# Configure Hishel, an httpx client with caching. -# Define one hour of caching time. -controller = hishel.Controller(allow_stale=True) -storage = hishel.SQLiteStorage(ttl=DOCS_CACHE_TTL) -client = hishel.CacheClient(controller=controller, storage=storage) - # Load CrateDB documentation outline. documentation_index = DocumentationIndex() @@ -39,9 +32,9 @@ def get_cratedb_documentation_index(): ' Only used to download CrateDB docs.') def fetch_cratedb_docs(link: str): """Fetches a CrateDB documentation link.""" - if not documentation_url_permitted(link): + if not documentation_index.url_permitted(link): raise ValueError(f'Link is not permitted: {link}') - return client.get(link, timeout=HTTP_TIMEOUT).text + return documentation_index.client.get(link, timeout=HTTP_TIMEOUT).text @mcp.tool(description="Returns an aggregation of all CrateDB's schema, tables and their metadata") def get_table_metadata() -> list[dict]: diff --git a/cratedb_mcp/knowledge.py b/cratedb_mcp/knowledge.py index 4ae5a3b..ad0a59f 100644 --- a/cratedb_mcp/knowledge.py +++ b/cratedb_mcp/knowledge.py @@ -1,7 +1,12 @@ # ruff: noqa: E501 +import typing as t + import cachetools +import hishel from cratedb_about import CrateDbKnowledgeOutline +from cratedb_mcp.settings import Settings + class Queries: TABLES_METADATA = """ @@ -108,16 +113,46 @@ class DocumentationIndex: ``` """ + settings = Settings() + + # List of permitted URL prefixes to acquire resources from on demand. 
+ permitted_urls: t.List[str] = [ + "https://cratedb.com/", + "https://github.com/crate", + "https://raw.githubusercontent.com/crate", + ] + def __init__(self): + + # Configure Hishel, an httpx client with caching. + # Define one hour of caching time. + controller = hishel.Controller(allow_stale=True) + storage = hishel.SQLiteStorage(ttl=self.settings.docs_cache_ttl()) + self.client = hishel.CacheClient(controller=controller, storage=storage) + + # Load documentation outline. self.outline = CrateDbKnowledgeOutline.load() - @cachetools.cached(cache={}) + @cachetools.cached(cache=cachetools.TTLCache(maxsize=1, ttl=settings.docs_cache_ttl() - 5)) def items(self): + """ + Return outline items, cached for a little bit less than one hour. + """ return self.outline.find_items().to_dict() + def url_permitted(self, url: str) -> bool: + """ + Validate if a documentation URL is from a permitted domain. + + Only URLs from CrateDB domains and specific GitHub repositories are allowed. + + Args: + url: The URL to validate -def documentation_url_permitted(url: str) -> bool: - return ( - url.startswith("https://cratedb.com/") or - url.startswith("https://github.com/crate") or - url.startswith("https://raw.githubusercontent.com/crate")) + Returns: + bool: True if the URL is from a permitted domain, False otherwise + """ + for permitted_url in self.permitted_urls: + if url.startswith(permitted_url): + return True + return False diff --git a/cratedb_mcp/settings.py b/cratedb_mcp/settings.py index 1ff7152..84c77d1 100644 --- a/cratedb_mcp/settings.py +++ b/cratedb_mcp/settings.py @@ -5,23 +5,44 @@ HTTP_URL: str = os.getenv("CRATEDB_MCP_HTTP_URL", "http://localhost:4200") -# Configure cache lifetime for documentation resources. -DOCS_CACHE_TTL: int = 3600 -try: - DOCS_CACHE_TTL = int(os.getenv("CRATEDB_MCP_DOCS_CACHE_TTL", DOCS_CACHE_TTL)) -except ValueError as e: # pragma: no cover - # If the environment variable is not a valid integer, use the default value, but warn about it. 
- # TODO: Add software test after refactoring away from module scope. - warnings.warn(f"Environment variable `CRATEDB_MCP_DOCS_CACHE_TTL` invalid: {e}. " - f"Using default value: {DOCS_CACHE_TTL}.", category=UserWarning, stacklevel=2) - # Configure HTTP timeout for all conversations. HTTP_TIMEOUT = 10.0 -# Whether to permit all statements. By default, only SELECT operations are permitted. -PERMIT_ALL_STATEMENTS: bool = to_bool(os.getenv("CRATEDB_MCP_PERMIT_ALL_STATEMENTS", "false")) -# TODO: Refactor into code which is not on the module level. Use OOM early. -if PERMIT_ALL_STATEMENTS: # pragma: no cover - warnings.warn("All types of SQL statements are permitted. This means the LLM " - "agent can write and modify the connected database", category=UserWarning, stacklevel=2) +class Settings: + """ + Application settings bundle. + """ + + @staticmethod + def permit_all_statements() -> bool: + """ + Whether to permit all statements. By default, only SELECT operations are permitted. + """ + permitted = False + try: + permitted = to_bool(os.getenv("CRATEDB_MCP_PERMIT_ALL_STATEMENTS", "false")) + if permitted: + warnings.warn("All types of SQL statements are permitted. " + "This means the LLM agent can write and modify the connected database", + category=UserWarning, stacklevel=2) + except (ValueError, TypeError) as e: + # If the environment variable is not a valid boolean, use the default value, but warn about it. + # TODO: Add software test after refactoring away from module scope. + warnings.warn(f"Environment variable `CRATEDB_MCP_PERMIT_ALL_STATEMENTS` invalid: {e}. ", + category=UserWarning, stacklevel=2) + return permitted + + @staticmethod + def docs_cache_ttl(ttl: int = 3600) -> int: + """ + Return cache lifetime for documentation resources, in seconds. + """ + try: + return int(os.getenv("CRATEDB_MCP_DOCS_CACHE_TTL", ttl)) + except ValueError as e: # pragma: no cover + # If the environment variable is not a valid integer, use the default value, but warn about it. 
+ # TODO: Add software test after refactoring away from module scope. + warnings.warn(f"Environment variable `CRATEDB_MCP_DOCS_CACHE_TTL` invalid: {e}. " + f"Using default value: {ttl}.", category=UserWarning, stacklevel=2) + return ttl diff --git a/cratedb_mcp/util/sql.py b/cratedb_mcp/util/sql.py index be7092a..c1ea741 100644 --- a/cratedb_mcp/util/sql.py +++ b/cratedb_mcp/util/sql.py @@ -5,7 +5,7 @@ import sqlparse from sqlparse.tokens import Keyword -from cratedb_mcp.settings import PERMIT_ALL_STATEMENTS +from cratedb_mcp.settings import Settings logger = logging.getLogger(__name__) @@ -21,7 +21,7 @@ def sql_is_permitted(expression: str) -> bool: Issue: https://github.com/crate/cratedb-mcp/issues/10 Question: Does SQLAlchemy provide a solid read-only mode, or any other library? """ - is_dql = SqlStatementClassifier(expression=expression, permit_all=PERMIT_ALL_STATEMENTS).is_dql + is_dql = SqlStatementClassifier(expression=expression, permit_all=Settings.permit_all_statements()).is_dql if is_dql: logger.info(f"Permitted SQL expression: {expression and expression[:50]}...") else: diff --git a/tests/test_util.py b/tests/test_util.py index ac5410d..f74e1e9 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,4 +1,7 @@ -import cratedb_mcp +import os + +import pytest + from cratedb_mcp.util.sql import sql_is_permitted @@ -14,10 +17,20 @@ def test_sql_select_rejected(): assert sql_is_permitted(r"--\; select 42") is False -def test_sql_insert_allowed(mocker): - """When explicitly allowed, permit any kind of statement""" - mocker.patch.object(cratedb_mcp.util.sql, "PERMIT_ALL_STATEMENTS", True) - assert sql_is_permitted("INSERT INTO foobar") is True +def test_sql_insert_permit_success(mocker): + """When explicitly allowed, permit any kind of statement, but verify there is a warning""" + mocker.patch.dict(os.environ, {"CRATEDB_MCP_PERMIT_ALL_STATEMENTS": "true"}) + with pytest.warns(UserWarning) as record: + assert sql_is_permitted("INSERT INTO foobar") is True 
+ assert "All types of SQL statements are permitted" in record[0].message.args[0] + + +def test_sql_insert_permit_invalid(mocker): + """Verify invalid environment variable""" + mocker.patch.dict(os.environ, {"CRATEDB_MCP_PERMIT_ALL_STATEMENTS": "-555"}) + with pytest.warns(UserWarning) as record: + assert sql_is_permitted("INSERT INTO foobar") is False + assert "Environment variable `CRATEDB_MCP_PERMIT_ALL_STATEMENTS` invalid" in record[0].message.args[0] def test_sql_select_multiple_rejected():