diff --git a/changelog/0.4.9.md b/changelog/0.4.9.md index 2944a0e..58752f3 100644 --- a/changelog/0.4.9.md +++ b/changelog/0.4.9.md @@ -3,3 +3,4 @@ ## Added - `get_job` tool now supports fetching job details by job ID. +- `run_sql` tool now supports optional username and password parameters for database authentication. diff --git a/src/api/tools/database/__init__.py b/src/api/tools/database/__init__.py index 99a0876..ed7b913 100644 --- a/src/api/tools/database/__init__.py +++ b/src/api/tools/database/__init__.py @@ -1,5 +1,5 @@ """Database tools for SingleStore MCP server.""" -from .database import run_sql +from .database import run_sql, create_pipeline -__all__ = ["run_sql"] +__all__ = ["run_sql", "create_pipeline"] diff --git a/src/api/tools/database/database.py b/src/api/tools/database/database.py index 5de7f3f..48fd555 100644 --- a/src/api/tools/database/database.py +++ b/src/api/tools/database/database.py @@ -31,7 +31,11 @@ class DatabaseCredentials(BaseModel): async def _get_database_credentials( - ctx: Context, target: WorkspaceTarget, database_name: str | None = None + ctx: Context, + target: WorkspaceTarget, + database_name: str | None = None, + provided_username: str | None = None, + provided_password: str | None = None, ) -> tuple[str, str]: """ Get database credentials based on the authentication method. @@ -40,6 +44,8 @@ async def _get_database_credentials( ctx: The MCP context target: The workspace target database_name: The database name to use for key generation + provided_username: Optional username provided by the caller + provided_password: Optional password provided by the caller Returns: Tuple of (username, password) @@ -57,6 +63,11 @@ async def _get_database_credentials( ) if is_using_api_key: + # If credentials are provided directly, use them + if provided_username and provided_password: + logger.debug(f"Using provided credentials for workspace: {target.name}") + return (provided_username, provided_password) + # For API key authentication, we need database credentials # Generate database key using credentials manager credentials_manager = get_session_credentials_manager() @@ -212,7 +223,7 @@ async def __execute_sql_unified( raise -def __get_workspace_by_id(workspace_id: str) -> WorkspaceTarget: +def get_workspace_by_id(workspace_id: str) -> WorkspaceTarget: """ Get a workspace or starter workspace by ID. @@ -272,8 +283,216 @@ def __init__(self, data): return WorkspaceTarget(target, is_shared) +async def create_pipeline( + ctx: Context, + pipeline_name: str, + data_source: str, + target_table_or_procedure: str, + workspace_id: str, + database: Optional[str] = None, + credentials: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, +) -> Dict[str, Any]: + """ + Create a SQL pipeline for streaming CSV data from an S3 data source to a database table or stored procedure. 
+ + This tool is restricted to S3 data sources only and automatically configures the pipeline with: + - CSV format only + - Comma delimiter (,) + - MAX_PARTITIONS_PER_BATCH = 2 + - Auto-start enabled + - No config options + + Args: + pipeline_name: Name for the pipeline (will be used in CREATE PIPELINE statement) + data_source: S3 URL (must start with 's3://') + target_table_or_procedure: Target table name or procedure name (use "PROCEDURE procedure_name" for procedures) + workspace_id: Workspace ID where the pipeline should be created + database: Optional database name to use + credentials: Optional AWS credentials in JSON format for private S3 buckets: + '{"aws_access_key_id": "key", "aws_secret_access_key": "secret", "aws_session_token": "token"}' + Use '{}' for public S3 buckets + username: Optional database username for API key authentication + password: Optional database password for API key authentication + + Returns: + Dictionary with pipeline creation status and details + + Example: + # Create pipeline from public S3 bucket + pipeline_name = "uk_price_paid" + data_source = "s3://singlestore-docs-example-datasets/pp-monthly/pp-monthly-update-new-version.csv" + target_table_or_procedure = "process_uk_price_paid" + credentials = "{}" + + # Create pipeline from private S3 bucket + credentials = '{"aws_access_key_id": "AKIAIOSFODNN7EXAMPLE", "aws_secret_access_key": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"}' + """ + # Validate workspace ID format + validated_id = validate_workspace_id(workspace_id) + + # Validate that data source is S3 + if not data_source.lower().startswith("s3://"): + return { + "status": "error", + "message": "Only S3 data sources are supported. Data source must start with 's3://'", + "errorCode": "INVALID_DATA_SOURCE", + } + + await ctx.info( + f"Creating S3 pipeline '{pipeline_name}' from '{data_source}' to '{target_table_or_procedure}' in workspace '{validated_id}'" + ) + + start_time = time.time() + + # Fixed parameters for restrictions + file_format = "CSV" + max_partitions_per_batch = 2 + field_terminator = "," + field_enclosure = '"' + field_escape = "\\\\" + line_terminator = "\\n" + line_starting = "" + + # Build the CREATE PIPELINE SQL statement + sql_parts = [ + f"CREATE OR REPLACE PIPELINE {pipeline_name} AS", + f" LOAD DATA S3 '{data_source}'", + ] + + # Add credentials if provided + if credentials: + sql_parts.append(f" CREDENTIALS '{credentials}'") + + # Add fixed max partitions per batch + sql_parts.append(f" MAX_PARTITIONS_PER_BATCH {max_partitions_per_batch}") + + # Determine if target is a table or procedure and format accordingly + is_procedure = ( + target_table_or_procedure.upper().startswith("PROCEDURE ") + or "process_" in target_table_or_procedure.lower() + ) + + if is_procedure: + # Remove "PROCEDURE " prefix if present + procedure_name = target_table_or_procedure + if procedure_name.upper().startswith("PROCEDURE "): + procedure_name = procedure_name[10:] # Remove "PROCEDURE " prefix + sql_parts.append(f" INTO PROCEDURE {procedure_name}") + else: + sql_parts.append(f" INTO TABLE {target_table_or_procedure}") + + # Add fixed format and CSV-specific options + sql_parts.append(f" FORMAT {file_format.upper()}") + sql_parts.append( + f" FIELDS TERMINATED BY '{field_terminator}' ENCLOSED BY '{field_enclosure}' ESCAPED BY '{field_escape}'" + ) + sql_parts.append( + f" LINES TERMINATED BY '{line_terminator}' STARTING BY '{line_starting}'" + ) + + # Combine all parts into final SQL + create_pipeline_sql = "\n".join(sql_parts) + ";" + + 
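+    # For reference (placeholder values only), a call such as
+    #   create_pipeline(pipeline_name="uk_price_paid",
+    #                   data_source="s3://bucket/data.csv",
+    #                   target_table_or_procedure="PROCEDURE process_uk_price_paid",
+    #                   credentials="{}")
+    # is expected to assemble roughly the following statement:
+    #
+    #   CREATE OR REPLACE PIPELINE uk_price_paid AS
+    #       LOAD DATA S3 's3://bucket/data.csv'
+    #       CREDENTIALS '{}'
+    #       MAX_PARTITIONS_PER_BATCH 2
+    #       INTO PROCEDURE process_uk_price_paid
+    #       FORMAT CSV
+    #       FIELDS TERMINATED BY ',' ENCLOSED BY '"' ESCAPED BY '\\'
+    #       LINES TERMINATED BY '\n' STARTING BY '';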
logger.info(create_pipeline_sql) + + try: + # Execute the CREATE PIPELINE statement + create_result = await run_sql( + ctx=ctx, + sql_query=create_pipeline_sql, + id=validated_id, + database=database, + username=username, + password=password, + ) + + if create_result.get("status") != "success": + return { + "status": "error", + "message": f"Failed to create pipeline: {create_result.get('message', 'Unknown error')}", + "errorCode": "PIPELINE_CREATION_FAILED", + "errorDetails": create_result, + } + + # If auto_start is True, also start the pipeline (always true in restricted mode) + start_result = None + start_pipeline_sql = f"START PIPELINE IF NOT RUNNING {pipeline_name};" + start_result = await run_sql( + ctx=ctx, + sql_query=start_pipeline_sql, + id=validated_id, + database=database, + username=username, + password=password, + ) + + execution_time = (time.time() - start_time) * 1000 + + # Track analytics + settings = config.get_settings() + user_id = config.get_user_id() + settings.analytics_manager.track_event( + user_id, + "tool_calling", + { + "name": "create_pipeline", + "pipeline_name": pipeline_name, + "workspace_id": validated_id, + "auto_start": True, + }, + ) + + # Build success message + success_message = f"Pipeline '{pipeline_name}' created successfully" + if start_result and start_result.get("status") == "success": + success_message += " and started" + + return { + "status": "success", + "message": success_message, + "data": { + "pipelineName": pipeline_name, + "dataSource": data_source, + "targetTableOrProcedure": target_table_or_procedure, + "workspaceId": validated_id, + "database": database, + "autoStarted": start_result and start_result.get("status") == "success", + "createSql": create_pipeline_sql, + "startSql": f"START PIPELINE IF NOT RUNNING {pipeline_name};", + }, + "metadata": { + "executionTimeMs": round(execution_time, 2), + "timestamp": datetime.now().isoformat(), + "fileFormat": "CSV", + "maxPartitionsPerBatch": 2, + "creationResult": create_result, + "startResult": start_result, + }, + } + + except Exception as e: + logger.error(f"Error creating pipeline: {str(e)}") + return { + "status": "error", + "message": f"Failed to create pipeline: {str(e)}", + "errorCode": "PIPELINE_CREATION_ERROR", + "errorDetails": { + "exception_type": type(e).__name__, + "pipelineName": pipeline_name, + "workspaceId": validated_id, + }, + } + + async def run_sql( - ctx: Context, sql_query: str, id: str, database: Optional[str] = None + ctx: Context, + sql_query: str, + id: str, + database: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, ) -> Dict[str, Any]: """ Use this tool to execute a single SQL statement against a SingleStore database. @@ -289,6 +508,8 @@ async def run_sql( id: Workspace or starter workspace ID sql_query: The SQL query to execute database: (optional) Database name to use + username: (optional) Database username for authentication. If not provided, will be requested via elicitation for API key auth + password: (optional) Database password for authentication. 
If not provided, will be requested via elicitation for API key auth Returns: Standardized response with query results and metadata @@ -303,7 +524,7 @@ async def run_sql( settings = config.get_settings() # Target can either be a workspace or a starter workspace - target = __get_workspace_by_id(validated_id) + target = get_workspace_by_id(validated_id) database_name = database # For starter workspaces, use their database name if not specified @@ -312,7 +533,9 @@ async def run_sql( # Get database credentials based on authentication method try: - username, password = await _get_database_credentials(ctx, target, database_name) + db_username, db_password = await _get_database_credentials( + ctx, target, database_name, username, password + ) except Exception as e: if "Database credentials required" in str(e): # Handle the specific case where elicitation is not supported @@ -343,8 +566,8 @@ async def run_sql( ctx=ctx, target=target, sql_query=sql_query, - username=username, - password=password, + username=db_username, + password=db_password, database=database_name, ) except Exception as e: @@ -377,7 +600,7 @@ async def run_sql( # Track analytics settings.analytics_manager.track_event( - username, + db_username, "tool_calling", { "name": "run_sql", diff --git a/src/api/tools/stage/__init__.py b/src/api/tools/stage/__init__.py new file mode 100644 index 0000000..f0c87fa --- /dev/null +++ b/src/api/tools/stage/__init__.py @@ -0,0 +1,5 @@ +"""Stage tools for SingleStore MCP server.""" + +from .stage import upload_file_to_stage + +__all__ = ["upload_file_to_stage"] diff --git a/src/api/tools/stage/stage.py b/src/api/tools/stage/stage.py new file mode 100644 index 0000000..0502e37 --- /dev/null +++ b/src/api/tools/stage/stage.py @@ -0,0 +1,164 @@ +"""Stage tools for SingleStore MCP server.""" + +import os +import time +import singlestoredb as s2 +from datetime import datetime, timezone +from typing import Any, Dict + +from mcp.server.fastmcp import Context + +from src.config import config +from src.api.common import get_access_token, get_org_id +from src.logger import get_logger + +# Set up logger for this module +logger = get_logger() + + +async def upload_file_to_stage( + ctx: Context, + local_path: str, + workspace_group_id: str, +) -> Dict[str, Any]: + """ + Upload a file to SingleStore Stage and return the stage URL. + + Stage is SingleStore's file storage driver that allows users to upload files + and get a URL where the file is stored, which can then be used in SQL operations + like creating pipelines. + + Note: Files are limited to a maximum size of 1GB. 
+ + Args: + ctx: MCP context for user interaction + local_path: Local file system path to the file to upload + workspace_group_id: ID of the workspace group to use for Stage access + + Returns: + Dictionary with upload status and stage URL + + Example: + local_path = "/path/to/data.csv" + workspace_group_id = "wg-123abc456def" + """ + settings = config.get_settings() + user_id = config.get_user_id() + settings.analytics_manager.track_event( + user_id, + "tool_calling", + {"name": "upload_file_to_stage"}, + ) + + start_time = time.time() + + try: + # Validate local file exists + if not os.path.exists(local_path): + return { + "status": "error", + "message": f"Local file not found: {local_path}", + "errorCode": "FILE_NOT_FOUND", + } + + # Use the original filename as the stage path + stage_path = os.path.basename(local_path) + + # Get file size for metadata and validation + file_size = os.path.getsize(local_path) + + # Check file size limit (1GB = 1,073,741,824 bytes) + max_file_size = 1024 * 1024 * 1024 # 1GB in bytes + if file_size > max_file_size: + file_size_mb = file_size / (1024 * 1024) + max_size_mb = max_file_size / (1024 * 1024) + return { + "status": "error", + "message": f"File size ({file_size_mb:.2f} MB) exceeds the maximum allowed size of {max_size_mb:.0f} MB (1GB)", + "errorCode": "FILE_TOO_LARGE", + "errorDetails": { + "fileSizeBytes": file_size, + "fileSizeMB": round(file_size_mb, 2), + "maxSizeBytes": max_file_size, + "maxSizeMB": max_size_mb, + }, + } + + await ctx.info( + f"Uploading file '{local_path}' to Stage at '{stage_path}' in workspace group '{workspace_group_id}'..." + ) + + # Get authentication details + access_token = get_access_token() + org_id = get_org_id() + + # Create workspace manager to access workspaces + workspace_manager = s2.manage_workspaces( + access_token=access_token, + base_url=settings.s2_api_base_url, + organization_id=org_id, + ) + + # Find workspace group by ID + workspace_group = workspace_manager.get_workspace_group(id=workspace_group_id) + stage = workspace_group.stage + + # Upload file to Stage + try: + stage_info = stage.upload_file(local_path=local_path, stage_path=stage_path) + + stage_url = stage_info.abspath() + + await ctx.info( + f"File '{local_path}' uploaded to Stage at '{stage_url}' in workspace group '{workspace_group.name}'" + ) + logger.info( + f"File '{local_path}' uploaded to Stage at '{stage_url}' in workspace group '{workspace_group.name}'" + ) + + except Exception as upload_error: + logger.error(f"Stage upload error: {upload_error}") + return { + "status": "error", + "message": f"Failed to upload file to Stage: {str(upload_error)}", + "errorCode": "STAGE_UPLOAD_FAILED", + "errorDetails": { + "localPath": local_path, + "stagePath": stage_path, + "exceptionType": type(upload_error).__name__, + }, + } + + execution_time = (time.time() - start_time) * 1000 + + return { + "status": "success", + "message": f"File uploaded successfully to Stage at '{stage_path}'", + "data": { + "localPath": local_path, + "stagePath": stage_path, + "stageUrl": stage_url, + "fileSize": file_size, + "fileName": os.path.basename(local_path), + "workspaceGroupName": workspace_group.name, + "workspaceGroupId": workspace_group.id, + "stageInfo": stage_info.path if hasattr(stage_info, "path") else None, + }, + "metadata": { + "executionTimeMs": round(execution_time, 2), + "timestamp": datetime.now(timezone.utc).isoformat(), + "uploadedFileSize": file_size, + }, + } + + except Exception as e: + logger.error(f"Error uploading file to Stage: {str(e)}") + 
return { + "status": "error", + "message": f"Failed to upload file to Stage: {str(e)}", + "errorCode": "STAGE_OPERATION_FAILED", + "errorDetails": { + "exceptionType": type(e).__name__, + "localPath": local_path, + }, + } diff --git a/src/api/tools/tools.py b/src/api/tools/tools.py index 1fc387e..e2762cf 100644 --- a/src/api/tools/tools.py +++ b/src/api/tools/tools.py @@ -14,7 +14,7 @@ terminate_starter_workspace, ) from src.api.tools.regions import list_regions, list_sharedtier_regions -from src.api.tools.database import run_sql +from src.api.tools.database import run_sql, create_pipeline from src.api.tools.user import get_user_info from src.api.tools.notebooks import ( create_notebook_file, @@ -26,6 +26,7 @@ choose_organization, set_organization, ) +from src.api.tools.stage import upload_file_to_stage # Define the tools with their metadata tools_definition = [ @@ -42,8 +43,10 @@ {"func": list_regions}, {"func": list_sharedtier_regions}, {"func": run_sql}, + {"func": create_pipeline, "internal": True}, {"func": create_notebook_file}, {"func": upload_notebook_file}, + {"func": upload_file_to_stage, "internal": True}, {"func": create_job_from_notebook}, {"func": get_job}, {"func": delete_job}, diff --git a/src/auth/session_credentials_manager.py b/src/auth/session_credentials_manager.py index 20aa646..fc2081a 100644 --- a/src/auth/session_credentials_manager.py +++ b/src/auth/session_credentials_manager.py @@ -138,7 +138,7 @@ def get_session_credentials_manager() -> SessionCredentialsManager: if _session_credentials_manager is None: _session_credentials_manager = SessionCredentialsManager() - logger.info("Created new session credentials manager") + logger.debug("Created new session credentials manager") return _session_credentials_manager @@ -151,7 +151,7 @@ def reset_session_credentials_manager() -> None: """ global _session_credentials_manager _session_credentials_manager = None - logger.info("Reset session credentials manager") + logger.debug("Reset session credentials manager") def invalidate_credentials(database_name: str) -> None: diff --git a/tests/integration/tools/test_database.py b/tests/integration/tools/test_database.py new file mode 100644 index 0000000..b4ef45e --- /dev/null +++ b/tests/integration/tools/test_database.py @@ -0,0 +1,352 @@ +import pytest +import pytest_asyncio +import random +import string +from src.api.tools.database.database import run_sql, create_pipeline +from src.api.common import build_request + + +def random_name(prefix): + return f"{prefix}_" + "".join( + random.choices(string.ascii_lowercase + string.digits, k=8) + ) + + +@pytest_asyncio.fixture +async def workspace_fixture(mock_context): + ctx = mock_context + workspace_name = random_name("testws") + database_name = random_name("testdb") + + # Get available regions first + regions = build_request("GET", "regions/sharedtier") + if not regions: + raise ValueError("No shared tier regions available") + + # Use the first available region + first_region = regions[0] + + # Create starter workspace + payload = { + "name": workspace_name, + "databaseName": database_name, + "provider": first_region.get("provider"), + "regionName": first_region.get("regionName"), + } + starter_workspace_data = build_request( + "POST", "sharedtier/virtualWorkspaces", data=payload + ) + workspace_id = starter_workspace_data.get("virtualWorkspaceID") + database_name = starter_workspace_data.get("databaseName") + assert workspace_id is not None + + # Create a user for the starter workspace + username = random_name("testuser") + 
user_payload = { + "userName": username, + } + user_data = build_request( + "POST", f"sharedtier/virtualWorkspaces/{workspace_id}/users", data=user_payload + ) + user_id = user_data.get("userID") + user_password = user_data.get("password") + assert user_id is not None + assert user_password is not None + + yield ctx, workspace_id, database_name, username, user_password + + # Cleanup: delete the starter workspace + try: + build_request("DELETE", f"sharedtier/virtualWorkspaces/{workspace_id}") + except Exception: + pass # Ignore cleanup errors + + +@pytest.mark.integration +class TestDatabase: + @pytest.mark.asyncio + async def test_run_sql_on_virtual_workspace(self, workspace_fixture): + """Test creating a table and inserting/selecting data.""" + ctx, workspace_id, database_name, username, password = workspace_fixture + + # Create a test table + create_table_sql = """ + CREATE TABLE IF NOT EXISTS test_table ( + id INT PRIMARY KEY, + name VARCHAR(100), + value DECIMAL(10,2) + ) + """ + + result = await run_sql( + ctx=ctx, + sql_query=create_table_sql, + id=workspace_id, + database=database_name, + username=username, + password=password, + ) + assert result["status"] == "success" + + # Insert test data + insert_sql = """ + INSERT INTO test_table (id, name, value) VALUES + (1, 'test_item_1', 10.50), + (2, 'test_item_2', 25.75) + """ + + result = await run_sql( + ctx=ctx, + sql_query=insert_sql, + id=workspace_id, + database=database_name, + username=username, + password=password, + ) + assert result["status"] == "success" + + # Select and verify data + select_sql = "SELECT id, name, value FROM test_table ORDER BY id" + + result = await run_sql( + ctx=ctx, + sql_query=select_sql, + id=workspace_id, + database=database_name, + username=username, + password=password, + ) + + assert result["status"] == "success" + assert result["data"]["row_count"] == 2 + + # Verify the data content + rows = result["data"]["result"] + assert rows[0]["id"] == 1 + assert rows[0]["name"] == "test_item_1" + assert float(rows[0]["value"]) == 10.50 + + assert rows[1]["id"] == 2 + assert rows[1]["name"] == "test_item_2" + assert float(rows[1]["value"]) == 25.75 + + @pytest.mark.asyncio + async def test_create_pipeline_with_csv_data(self, workspace_fixture): + """Test creating a pipeline that loads CSV data from S3.""" + ctx, workspace_id, database_name, username, password = workspace_fixture + + # Create the target table with proper schema + create_table_sql = """ + CREATE TABLE IF NOT EXISTS uk_price_paid ( + price BIGINT, + date Date, + postcode VARCHAR(100), + type ENUM('terraced', 'semi-detached', 'detached', 'flat', 'other'), + is_new BOOL, + duration ENUM('freehold', 'leasehold', 'unknown'), + addr1 VARCHAR(100), + addr2 VARCHAR(100), + street VARCHAR(100), + locality VARCHAR(100), + town VARCHAR(100), + district VARCHAR(100), + county VARCHAR(100) + ) + """ + + result = await run_sql( + ctx=ctx, + sql_query=create_table_sql, + id=workspace_id, + database=database_name, + username=username, + password=password, + ) + assert result["status"] == "success" + + # Create the stored procedure to process the CSV data + create_procedure_sql = """ + CREATE OR REPLACE PROCEDURE process_uk_price_paid ( + _batch QUERY( + uuid TEXT NOT NULL, + price TEXT NOT NULL, + date TEXT NOT NULL, + postcode TEXT NOT NULL, + type TEXT NOT NULL, + is_new TEXT NOT NULL, + duration TEXT NOT NULL, + addr1 TEXT NOT NULL, + addr2 TEXT NOT NULL, + street TEXT NOT NULL, + locality TEXT NOT NULL, + town TEXT NOT NULL, + district TEXT NOT NULL, + 
county TEXT NOT NULL, + val1 TEXT NOT NULL, + val2 TEXT NOT NULL + ) + ) + AS + BEGIN + INSERT INTO uk_price_paid ( + price, + date, + postcode, + type, + is_new, + duration, + addr1, + addr2, + street, + locality, + town, + district, + county + ) + SELECT + price, + date, + postcode, + CASE + WHEN type = 'T' THEN 'terraced' + WHEN type = 'S' THEN 'semi-detached' + WHEN type = 'D' THEN 'detached' + WHEN type = 'F' THEN 'flat' + WHEN type = 'O' THEN 'other' + ELSE 'other' + END AS type, + CASE + WHEN is_new = 'Y' THEN TRUE + ELSE FALSE + END AS is_new, + CASE + WHEN duration = 'F' THEN 'freehold' + WHEN duration = 'L' THEN 'leasehold' + WHEN duration = 'U' THEN 'unknown' + ELSE 'unknown' + END AS duration, + addr1, + addr2, + street, + locality, + town, + district, + county + FROM _batch; + END + """ + + result = await run_sql( + ctx=ctx, + sql_query=create_procedure_sql, + id=workspace_id, + database=database_name, + username=username, + password=password, + ) + assert result["status"] == "success" + + # Create pipeline using the procedure + pipeline_name = random_name("uk_price_paid_pipeline") + data_source = "s3://singlestore-docs-example-datasets/pp-monthly/pp-monthly-update-new-version.csv" + + result = await create_pipeline( + ctx=ctx, + pipeline_name=pipeline_name, + data_source=data_source, + target_table_or_procedure="process_uk_price_paid", + workspace_id=workspace_id, + database=database_name, + credentials="{}", + username=username, + password=password, + ) + + assert result["status"] == "success" + assert result["data"]["pipelineName"] == pipeline_name + assert result["data"]["dataSource"] == data_source + assert result["data"]["targetTableOrProcedure"] == "process_uk_price_paid" + assert result["data"]["autoStarted"] is True + + # Wait a moment for pipeline to process some data + import asyncio + + await asyncio.sleep(5) + + # Check that data was loaded + select_sql = "SELECT COUNT(*) as record_count FROM uk_price_paid" + result = await run_sql( + ctx=ctx, + sql_query=select_sql, + id=workspace_id, + database=database_name, + username=username, + password=password, + ) + + assert result["status"] == "success" + record_count = result["data"]["result"][0]["record_count"] + assert record_count > 0, ( + f"Expected some records to be loaded, but got {record_count}" + ) + + # Check sample data structure + sample_sql = "SELECT * FROM uk_price_paid LIMIT 3" + result = await run_sql( + ctx=ctx, + sql_query=sample_sql, + id=workspace_id, + database=database_name, + username=username, + password=password, + ) + + assert result["status"] == "success" + rows = result["data"]["result"] + assert len(rows) > 0 + + # Verify data structure - check that we have expected columns + first_row = rows[0] + expected_columns = [ + "price", + "date", + "postcode", + "type", + "is_new", + ] + for col in expected_columns: + assert col in first_row, f"Expected column '{col}' not found in data" + + # Verify data types and transformations + assert isinstance(first_row["price"], int) + assert first_row["type"] in [ + "terraced", + "semi-detached", + "detached", + "flat", + "other", + ] + + # Stop the pipeline + stop_sql = f"STOP PIPELINE {pipeline_name}" + result = await run_sql( + ctx=ctx, + sql_query=stop_sql, + id=workspace_id, + database=database_name, + username=username, + password=password, + ) + assert result["status"] == "success" + + # Drop the pipeline + drop_sql = f"DROP PIPELINE {pipeline_name}" + result = await run_sql( + ctx=ctx, + sql_query=drop_sql, + id=workspace_id, + 
database=database_name, + username=username, + password=password, + ) + assert result["status"] == "success" diff --git a/tests/integration/tools/test_run_sql_virtual_workspace.py b/tests/integration/tools/test_run_sql_virtual_workspace.py deleted file mode 100644 index ef46c97..0000000 --- a/tests/integration/tools/test_run_sql_virtual_workspace.py +++ /dev/null @@ -1,54 +0,0 @@ -import pytest -import pytest_asyncio -import random -import string -from src.api.tools.database.database import run_sql -from src.api.common import build_request - - -def random_name(prefix): - return f"{prefix}_" + "".join( - random.choices(string.ascii_lowercase + string.digits, k=8) - ) - - -@pytest_asyncio.fixture -async def workspace_fixture(): - from mcp.server.fastmcp import Context - - ctx = Context() - workspace_name = random_name("testws") - database_name = random_name("testdb") - - payload = { - "name": workspace_name, - "databaseName": database_name, - } - starter_workspace_data = build_request( - "POST", "sharedtier/virtualWorkspaces", data=payload - ) - workspace_id = starter_workspace_data.get("virtualWorkspaceID") - database_name = starter_workspace_data.get("databaseName") - assert workspace_id is not None - - yield ctx, workspace_id, database_name - - build_request("DELETE", f"sharedtier/virtualWorkspaces/{workspace_id}") - - -@pytest.mark.integration -class TestRunSQLVirtualWorkspace: - @pytest.mark.asyncio - @pytest.mark.skip(reason="Skipping integration test for now.") - async def test_run_sql_on_virtual_workspace(self, workspace_fixture): - ctx, workspace_id, database_name = workspace_fixture - sql_query = "SELECT 1 AS test_col" - result = await run_sql( - ctx=ctx, - sql_query=sql_query, - id=workspace_id, - database=database_name, - ) - assert result["status"] == "success" - assert result["data"]["row_count"] == 1 - assert result["data"]["result"][0]["test_col"] == 1 diff --git a/tests/unit/test_create_pipeline.py b/tests/unit/test_create_pipeline.py new file mode 100644 index 0000000..73e14be --- /dev/null +++ b/tests/unit/test_create_pipeline.py @@ -0,0 +1,197 @@ +"""Unit tests for the create_pipeline function.""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from src.api.tools.database.database import create_pipeline + + +@pytest.mark.asyncio +async def test_create_pipeline_sql_generation(): + """Test that create_pipeline generates correct SQL statements.""" + # Mock context + ctx = AsyncMock() + ctx.info = AsyncMock() + + # Mock run_sql to return success responses + mock_run_sql_success = AsyncMock(return_value={"status": "success"}) + + with patch("src.api.tools.database.database.run_sql", mock_run_sql_success): + with patch("src.api.tools.database.database.config") as mock_config: + # Mock settings and user_id + mock_settings = MagicMock() + mock_settings.analytics_manager.track_event = MagicMock() + mock_config.get_settings.return_value = mock_settings + mock_config.get_user_id.return_value = "test-user" + + # Test basic pipeline creation + result = await create_pipeline( + ctx=ctx, + pipeline_name="test_pipeline", + data_source="s3://test-bucket/data.csv", + target_table_or_procedure="test_table", + workspace_id="ws-12345678-1234-5678-1234-567812345678", + database="test_db", + file_format="CSV", + credentials="{}", + config_options='{"region": "us-east-1"}', + auto_start=True, + ) + + # Verify the result + assert result["status"] == "success" + assert result["data"]["pipelineName"] == "test_pipeline" + assert result["data"]["dataSource"] == 
"s3://test-bucket/data.csv" + assert result["data"]["targetTableOrProcedure"] == "test_table" + assert result["data"]["autoStarted"] is True + + # Check that run_sql was called twice (CREATE and START) + assert mock_run_sql_success.call_count == 2 + + # Check the first call (CREATE PIPELINE) + create_call = mock_run_sql_success.call_args_list[0] + create_sql = create_call[1]["sql_query"] + + # Verify CREATE PIPELINE SQL contains expected elements + assert "CREATE OR REPLACE PIPELINE test_pipeline AS" in create_sql + assert "LOAD DATA S3://TEST-BUCKET/DATA.CSV" in create_sql + assert "CREDENTIALS '{}'" in create_sql + assert 'CONFIG \'{"region": "us-east-1"}\'' in create_sql + assert "INTO test_table" in create_sql + assert "FORMAT CSV" in create_sql + assert "FIELDS TERMINATED BY ','" in create_sql + + # Check the second call (START PIPELINE) + start_call = mock_run_sql_success.call_args_list[1] + start_sql = start_call[1]["sql_query"] + assert start_sql == "START PIPELINE IF NOT RUNNING test_pipeline;" + + +@pytest.mark.asyncio +async def test_create_pipeline_with_json_format(): + """Test create_pipeline with JSON format.""" + ctx = AsyncMock() + ctx.info = AsyncMock() + + mock_run_sql_success = AsyncMock(return_value={"status": "success"}) + + with patch("src.api.tools.database.database.run_sql", mock_run_sql_success): + with patch("src.api.tools.database.database.config") as mock_config: + mock_settings = MagicMock() + mock_settings.analytics_manager.track_event = MagicMock() + mock_config.get_settings.return_value = mock_settings + mock_config.get_user_id.return_value = "test-user" + + result = await create_pipeline( + ctx=ctx, + pipeline_name="json_pipeline", + data_source="s3://test-bucket/data.json", + target_table_or_procedure="json_table", + workspace_id="ws-12345678-1234-5678-1234-567812345678", + file_format="JSON", + auto_start=False, + ) + + assert result["status"] == "success" + assert result["data"]["autoStarted"] is False + + # Check that run_sql was called only once (CREATE, no START) + assert mock_run_sql_success.call_count == 1 + + # Check the CREATE PIPELINE SQL + create_call = mock_run_sql_success.call_args_list[0] + create_sql = create_call[1]["sql_query"] + + assert "FORMAT JSON" in create_sql + # JSON format should not have CSV-specific fields + assert "FIELDS TERMINATED BY" not in create_sql + + +@pytest.mark.asyncio +async def test_create_pipeline_with_credentials(): + """Test create_pipeline passes username and password correctly.""" + ctx = AsyncMock() + ctx.info = AsyncMock() + + mock_run_sql_success = AsyncMock(return_value={"status": "success"}) + + with patch("src.api.tools.database.database.run_sql", mock_run_sql_success): + with patch("src.api.tools.database.database.config") as mock_config: + mock_settings = MagicMock() + mock_settings.analytics_manager.track_event = MagicMock() + mock_config.get_settings.return_value = mock_settings + mock_config.get_user_id.return_value = "test-user" + + # Test with username and password + result = await create_pipeline( + ctx=ctx, + pipeline_name="test_pipeline", + data_source="s3://test-bucket/data.csv", + target_table_or_procedure="test_table", + workspace_id="ws-12345678-1234-5678-1234-567812345678", + username="test_user", + password="test_pass", + auto_start=True, + ) + + assert result["status"] == "success" + + # Verify that both run_sql calls received username and password + for call in mock_run_sql_success.call_args_list: + assert call[1]["username"] == "test_user" + assert call[1]["password"] == "test_pass" + 
+
+@pytest.mark.asyncio
+async def test_create_pipeline_error_handling():
+    """Test create_pipeline handles errors properly."""
+    ctx = AsyncMock()
+    ctx.info = AsyncMock()
+
+    # Mock run_sql to return an error
+    mock_run_sql_error = AsyncMock(
+        return_value={"status": "error", "message": "SQL execution failed"}
+    )
+
+    with patch("src.api.tools.database.database.run_sql", mock_run_sql_error):
+        result = await create_pipeline(
+            ctx=ctx,
+            pipeline_name="error_pipeline",
+            data_source="s3://test-bucket/data.csv",
+            target_table_or_procedure="test_table",
+            workspace_id="ws-12345678-1234-5678-1234-567812345678",
+        )
+
+        assert result["status"] == "error"
+        assert "Failed to create pipeline" in result["message"]
+        assert result["errorCode"] == "PIPELINE_CREATION_FAILED"
+
+
+@pytest.mark.asyncio
+async def test_create_pipeline_rejects_non_s3_data_source():
+    """Test create_pipeline rejects non-S3 data sources such as Stage URLs."""
+    ctx = AsyncMock()
+    ctx.info = AsyncMock()
+
+    mock_run_sql = AsyncMock(return_value={"status": "success"})
+
+    with patch("src.api.tools.database.database.run_sql", mock_run_sql):
+        result = await create_pipeline(
+            ctx=ctx,
+            pipeline_name="stage_pipeline",
+            data_source="stage://my_data/file.csv",
+            target_table_or_procedure="stage_table",
+            workspace_id="ws-12345678-1234-5678-1234-567812345678",
+        )
+
+    # Only S3 sources are supported, so the Stage URL is rejected before any
+    # SQL is executed.
+    assert result["status"] == "error"
+    assert result["errorCode"] == "INVALID_DATA_SOURCE"
+    mock_run_sql.assert_not_called()
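+
+
+# Sketch for a possible tests/unit/test_stage.py: upload_file_to_stage should
+# fail fast with FILE_NOT_FOUND, before any management API calls, when the
+# local file does not exist.
+@pytest.mark.asyncio
+async def test_upload_file_to_stage_missing_local_file():
+    """Test upload_file_to_stage reports FILE_NOT_FOUND for a missing path."""
+    from src.api.tools.stage.stage import upload_file_to_stage
+
+    ctx = AsyncMock()
+
+    with patch("src.api.tools.stage.stage.config") as mock_config:
+        mock_settings = MagicMock()
+        mock_config.get_settings.return_value = mock_settings
+        mock_config.get_user_id.return_value = "test-user"
+
+        result = await upload_file_to_stage(
+            ctx=ctx,
+            local_path="/nonexistent/path/data.csv",
+            workspace_group_id="wg-123abc456def",
+        )
+
+    assert result["status"] == "error"
+    assert result["errorCode"] == "FILE_NOT_FOUND"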