From 87980890aee19f4ad0b6c926356b74535a449829 Mon Sep 17 00:00:00 2001 From: Pedro Rodrigues Date: Wed, 13 Aug 2025 19:07:50 +0100 Subject: [PATCH 1/2] pipelines and fixed run_sql --- changelog/0.4.9.md | 3 + src/api/tools/database/__init__.py | 4 +- src/api/tools/database/database.py | 221 +++++++++++++++++- src/api/tools/stage/__init__.py | 5 + src/api/tools/stage/stage.py | 164 +++++++++++++ src/api/tools/tools.py | 5 +- .../tools/test_run_sql_virtual_workspace.py | 101 +++++++- 7 files changed, 482 insertions(+), 21 deletions(-) create mode 100644 src/api/tools/stage/__init__.py create mode 100644 src/api/tools/stage/stage.py diff --git a/changelog/0.4.9.md b/changelog/0.4.9.md index 2944a0e..f3c4e64 100644 --- a/changelog/0.4.9.md +++ b/changelog/0.4.9.md @@ -3,3 +3,6 @@ ## Added - `get_job` tool now supports fetching job details by job ID. +- `run_sql` tool now supports optional username and password parameters for database authentication. +- `upload_file_to_stage` tool to upload a file to SingleStore's driver "Stage" +- `create_pipeline` tool to create a pipeline in SingleStore to process data from a stage to a SingleStore database diff --git a/src/api/tools/database/__init__.py b/src/api/tools/database/__init__.py index 99a0876..ed7b913 100644 --- a/src/api/tools/database/__init__.py +++ b/src/api/tools/database/__init__.py @@ -1,5 +1,5 @@ """Database tools for SingleStore MCP server.""" -from .database import run_sql +from .database import run_sql, create_pipeline -__all__ = ["run_sql"] +__all__ = ["run_sql", "create_pipeline"] diff --git a/src/api/tools/database/database.py b/src/api/tools/database/database.py index 5de7f3f..c126fd8 100644 --- a/src/api/tools/database/database.py +++ b/src/api/tools/database/database.py @@ -31,7 +31,11 @@ class DatabaseCredentials(BaseModel): async def _get_database_credentials( - ctx: Context, target: WorkspaceTarget, database_name: str | None = None + ctx: Context, + target: WorkspaceTarget, + database_name: str | None = None, + provided_username: str | None = None, + provided_password: str | None = None, ) -> tuple[str, str]: """ Get database credentials based on the authentication method. 
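The changelog above notes that run_sql now accepts optional username and password parameters. As a reading aid, a minimal usage sketch of that call shape, assuming an API-key session; the workspace ID, database name, and credentials below are placeholders, and the Context is normally supplied by the MCP server rather than constructed by hand:

    import asyncio
    from mcp.server.fastmcp import Context
    from src.api.tools.database.database import run_sql

    async def main() -> None:
        ctx = Context()  # normally provided by the MCP server; shown only to keep the sketch self-contained
        result = await run_sql(
            ctx=ctx,
            sql_query="SELECT 1",
            id="ws-00000000-0000-0000-0000-000000000000",  # placeholder workspace ID
            database="my_db",                              # placeholder database name
            username="db_user",      # optional: with API key auth this skips credential elicitation
            password="db_password",  # placeholder credentials
        )
        print(result["status"], result.get("data"))

    asyncio.run(main())
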
@@ -40,6 +44,8 @@ async def _get_database_credentials( ctx: The MCP context target: The workspace target database_name: The database name to use for key generation + provided_username: Optional username provided by the caller + provided_password: Optional password provided by the caller Returns: Tuple of (username, password) @@ -57,6 +63,11 @@ async def _get_database_credentials( ) if is_using_api_key: + # If credentials are provided directly, use them + if provided_username and provided_password: + logger.debug(f"Using provided credentials for workspace: {target.name}") + return (provided_username, provided_password) + # For API key authentication, we need database credentials # Generate database key using credentials manager credentials_manager = get_session_credentials_manager() @@ -272,8 +283,202 @@ def __init__(self, data): return WorkspaceTarget(target, is_shared) +async def create_pipeline( + ctx: Context, + pipeline_name: str, + data_source: str, + target_table_or_procedure: str, + workspace_id: str, + database: Optional[str] = None, + file_format: str = "CSV", + credentials: Optional[str] = None, + config_options: Optional[str] = None, + max_partitions_per_batch: Optional[int] = None, + field_terminator: str = ",", + field_enclosure: str = '"', + field_escape: str = "\\", + line_terminator: str = "\\n", + line_starting: str = "", + auto_start: bool = True, +) -> Dict[str, Any]: + """ + Create a SQL pipeline for streaming data from a data source to a database table or stored procedure. + + A pipeline is a SQL instruction that creates a stream of data from a source (like S3, Stage, etc.) + to a target table or procedure in SingleStore. This tool generates and executes the CREATE PIPELINE + SQL statement and optionally starts the pipeline. + + Args: + pipeline_name: Name for the pipeline (will be used in CREATE PIPELINE statement) + data_source: Source URL (e.g., S3 URL, Stage URL, etc.) + target_table_or_procedure: Target table name or procedure name (use "PROCEDURE procedure_name" for procedures) + workspace_id: Workspace ID where the pipeline should be created + database: Optional database name to use + file_format: File format (CSV, JSON, etc.) 
- default is CSV + credentials: Optional credentials string for accessing the data source + config_options: Optional configuration options (e.g., '{"region": "us-east-1"}') + max_partitions_per_batch: Optional max partitions per batch setting + field_terminator: Field terminator for CSV format (default: ",") + field_enclosure: Field enclosure character for CSV (default: '"') + field_escape: Field escape character for CSV (default: "\\") + line_terminator: Line terminator for CSV (default: "\\n") + line_starting: Line starting pattern for CSV (default: "") + auto_start: Whether to automatically start the pipeline after creation (default: True) + + Returns: + Dictionary with pipeline creation status and details + + Example: + # Create pipeline from S3 + pipeline_name = "uk_price_paid" + data_source = "s3://singlestore-docs-example-datasets/pp-monthly/pp-monthly-update-new-version.csv" + target_table_or_procedure = "PROCEDURE process_uk_price_paid" + credentials = "{}" + config_options = '{"region": "us-east-1"}' + + # Create pipeline from Stage + data_source = "stage://datasets/my_data.csv" + target_table_or_procedure = "my_target_table" + """ + # Validate workspace ID format + validated_id = validate_workspace_id(workspace_id) + + await ctx.info( + f"Creating pipeline '{pipeline_name}' from '{data_source}' to '{target_table_or_procedure}' in workspace '{validated_id}'" + ) + + start_time = time.time() + + # Build the CREATE PIPELINE SQL statement + sql_parts = [ + f"CREATE OR REPLACE PIPELINE {pipeline_name} AS", + f" LOAD DATA {data_source.upper() if data_source.startswith(('s3://', 'stage://')) else data_source}", + ] + + # Add credentials if provided + if credentials: + sql_parts.append(f" CREDENTIALS '{credentials}'") + + # Add config options if provided + if config_options: + sql_parts.append(f" CONFIG '{config_options}'") + + # Add max partitions per batch if provided + if max_partitions_per_batch: + sql_parts.append(f" MAX_PARTITIONS_PER_BATCH {max_partitions_per_batch}") + + # Add target (table or procedure) + sql_parts.append(f" INTO {target_table_or_procedure}") + + # Add format and CSV-specific options + sql_parts.append(f" FORMAT {file_format.upper()}") + + if file_format.upper() == "CSV": + sql_parts.extend( + [ + f" FIELDS TERMINATED BY '{field_terminator}' ENCLOSED BY '{field_enclosure}' ESCAPED BY '{field_escape}'", + f" LINES TERMINATED BY '{line_terminator}' STARTING BY '{line_starting}'", + ] + ) + + # Combine all parts into final SQL + create_pipeline_sql = "\n".join(sql_parts) + ";" + + try: + # Execute the CREATE PIPELINE statement + create_result = await run_sql( + ctx=ctx, sql_query=create_pipeline_sql, id=validated_id, database=database + ) + + if create_result.get("status") != "success": + return { + "status": "error", + "message": f"Failed to create pipeline: {create_result.get('message', 'Unknown error')}", + "errorCode": "PIPELINE_CREATION_FAILED", + "errorDetails": create_result, + } + + # If auto_start is True, also start the pipeline + start_result = None + if auto_start: + start_pipeline_sql = f"START PIPELINE IF NOT RUNNING {pipeline_name};" + start_result = await run_sql( + ctx=ctx, + sql_query=start_pipeline_sql, + id=validated_id, + database=database, + ) + + execution_time = (time.time() - start_time) * 1000 + + # Track analytics + settings = config.get_settings() + user_id = config.get_user_id() + settings.analytics_manager.track_event( + user_id, + "tool_calling", + { + "name": "create_pipeline", + "pipeline_name": pipeline_name, + "workspace_id": 
validated_id, + "auto_start": auto_start, + }, + ) + + # Build success message + success_message = f"Pipeline '{pipeline_name}' created successfully" + if auto_start and start_result and start_result.get("status") == "success": + success_message += " and started" + + return { + "status": "success", + "message": success_message, + "data": { + "pipelineName": pipeline_name, + "dataSource": data_source, + "targetTableOrProcedure": target_table_or_procedure, + "workspaceId": validated_id, + "database": database, + "autoStarted": auto_start + and start_result + and start_result.get("status") == "success", + "createSql": create_pipeline_sql, + "startSql": ( + f"START PIPELINE IF NOT RUNNING {pipeline_name};" + if auto_start + else None + ), + }, + "metadata": { + "executionTimeMs": round(execution_time, 2), + "timestamp": datetime.now().isoformat(), + "fileFormat": file_format, + "creationResult": create_result, + "startResult": start_result, + }, + } + + except Exception as e: + logger.error(f"Error creating pipeline: {str(e)}") + return { + "status": "error", + "message": f"Failed to create pipeline: {str(e)}", + "errorCode": "PIPELINE_CREATION_ERROR", + "errorDetails": { + "exception_type": type(e).__name__, + "pipelineName": pipeline_name, + "workspaceId": validated_id, + }, + } + + async def run_sql( - ctx: Context, sql_query: str, id: str, database: Optional[str] = None + ctx: Context, + sql_query: str, + id: str, + database: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, ) -> Dict[str, Any]: """ Use this tool to execute a single SQL statement against a SingleStore database. @@ -289,6 +494,8 @@ async def run_sql( id: Workspace or starter workspace ID sql_query: The SQL query to execute database: (optional) Database name to use + username: (optional) Database username for authentication. If not provided, will be requested via elicitation for API key auth + password: (optional) Database password for authentication. 
If not provided, will be requested via elicitation for API key auth Returns: Standardized response with query results and metadata @@ -312,7 +519,9 @@ async def run_sql( # Get database credentials based on authentication method try: - username, password = await _get_database_credentials(ctx, target, database_name) + db_username, db_password = await _get_database_credentials( + ctx, target, database_name, username, password + ) except Exception as e: if "Database credentials required" in str(e): # Handle the specific case where elicitation is not supported @@ -343,8 +552,8 @@ async def run_sql( ctx=ctx, target=target, sql_query=sql_query, - username=username, - password=password, + username=db_username, + password=db_password, database=database_name, ) except Exception as e: @@ -377,7 +586,7 @@ async def run_sql( # Track analytics settings.analytics_manager.track_event( - username, + db_username, "tool_calling", { "name": "run_sql", diff --git a/src/api/tools/stage/__init__.py b/src/api/tools/stage/__init__.py new file mode 100644 index 0000000..f0c87fa --- /dev/null +++ b/src/api/tools/stage/__init__.py @@ -0,0 +1,5 @@ +"""Stage tools for SingleStore MCP server.""" + +from .stage import upload_file_to_stage + +__all__ = ["upload_file_to_stage"] diff --git a/src/api/tools/stage/stage.py b/src/api/tools/stage/stage.py new file mode 100644 index 0000000..0502e37 --- /dev/null +++ b/src/api/tools/stage/stage.py @@ -0,0 +1,164 @@ +"""Stage tools for SingleStore MCP server.""" + +import os +import time +import singlestoredb as s2 +from datetime import datetime, timezone +from typing import Any, Dict + +from mcp.server.fastmcp import Context + +from src.config import config +from src.api.common import get_access_token, get_org_id +from src.logger import get_logger + +# Set up logger for this module +logger = get_logger() + + +async def upload_file_to_stage( + ctx: Context, + local_path: str, + workspace_group_id: str, +) -> Dict[str, Any]: + """ + Upload a file to SingleStore Stage and return the stage URL. + + Stage is SingleStore's file storage driver that allows users to upload files + and get a URL where the file is stored, which can then be used in SQL operations + like creating pipelines. + + Note: Files are limited to a maximum size of 1GB. 
+ + Args: + ctx: MCP context for user interaction + local_path: Local file system path to the file to upload + workspace_group_id: ID of the workspace group to use for Stage access + + Returns: + Dictionary with upload status and stage URL + + Example: + local_path = "/path/to/data.csv" + workspace_group_id = "wg-123abc456def" + """ + settings = config.get_settings() + user_id = config.get_user_id() + settings.analytics_manager.track_event( + user_id, + "tool_calling", + {"name": "upload_file_to_stage"}, + ) + + start_time = time.time() + + try: + # Validate local file exists + if not os.path.exists(local_path): + return { + "status": "error", + "message": f"Local file not found: {local_path}", + "errorCode": "FILE_NOT_FOUND", + } + + # Use the original filename as the stage path + stage_path = os.path.basename(local_path) + + # Get file size for metadata and validation + file_size = os.path.getsize(local_path) + + # Check file size limit (1GB = 1,073,741,824 bytes) + max_file_size = 1024 * 1024 * 1024 # 1GB in bytes + if file_size > max_file_size: + file_size_mb = file_size / (1024 * 1024) + max_size_mb = max_file_size / (1024 * 1024) + return { + "status": "error", + "message": f"File size ({file_size_mb:.2f} MB) exceeds the maximum allowed size of {max_size_mb:.0f} MB (1GB)", + "errorCode": "FILE_TOO_LARGE", + "errorDetails": { + "fileSizeBytes": file_size, + "fileSizeMB": round(file_size_mb, 2), + "maxSizeBytes": max_file_size, + "maxSizeMB": max_size_mb, + }, + } + + await ctx.info( + f"Uploading file '{local_path}' to Stage at '{stage_path}' in workspace group '{workspace_group_id}'..." + ) + + # Get authentication details + access_token = get_access_token() + org_id = get_org_id() + + # Create workspace manager to access workspaces + workspace_manager = s2.manage_workspaces( + access_token=access_token, + base_url=settings.s2_api_base_url, + organization_id=org_id, + ) + + # Find workspace group by ID + workspace_group = workspace_manager.get_workspace_group(id=workspace_group_id) + stage = workspace_group.stage + + # Upload file to Stage + try: + stage_info = stage.upload_file(local_path=local_path, stage_path=stage_path) + + stage_url = stage_info.abspath() + + await ctx.info( + f"File '{local_path}' uploaded to Stage at '{stage_url}' in workspace group '{workspace_group.name}'" + ) + logger.info( + f"File '{local_path}' uploaded to Stage at '{stage_url}' in workspace group '{workspace_group.name}'" + ) + + except Exception as upload_error: + logger.error(f"Stage upload error: {upload_error}") + return { + "status": "error", + "message": f"Failed to upload file to Stage: {str(upload_error)}", + "errorCode": "STAGE_UPLOAD_FAILED", + "errorDetails": { + "localPath": local_path, + "stagePath": stage_path, + "exceptionType": type(upload_error).__name__, + }, + } + + execution_time = (time.time() - start_time) * 1000 + + return { + "status": "success", + "message": f"File uploaded successfully to Stage at '{stage_path}'", + "data": { + "localPath": local_path, + "stagePath": stage_path, + "stageUrl": stage_url, + "fileSize": file_size, + "fileName": os.path.basename(local_path), + "workspaceGroupName": workspace_group.name, + "workspaceGroupId": workspace_group.id, + "stageInfo": stage_info.path if hasattr(stage_info, "path") else None, + }, + "metadata": { + "executionTimeMs": round(execution_time, 2), + "timestamp": datetime.now(timezone.utc).isoformat(), + "uploadedFileSize": file_size, + }, + } + + except Exception as e: + logger.error(f"Error uploading file to Stage: {str(e)}") + 
return { + "status": "error", + "message": f"Failed to upload file to Stage: {str(e)}", + "errorCode": "STAGE_OPERATION_FAILED", + "errorDetails": { + "exceptionType": type(e).__name__, + "localPath": local_path, + }, + } diff --git a/src/api/tools/tools.py b/src/api/tools/tools.py index 1fc387e..bfcd09b 100644 --- a/src/api/tools/tools.py +++ b/src/api/tools/tools.py @@ -14,7 +14,7 @@ terminate_starter_workspace, ) from src.api.tools.regions import list_regions, list_sharedtier_regions -from src.api.tools.database import run_sql +from src.api.tools.database import run_sql, create_pipeline from src.api.tools.user import get_user_info from src.api.tools.notebooks import ( create_notebook_file, @@ -26,6 +26,7 @@ choose_organization, set_organization, ) +from src.api.tools.stage import upload_file_to_stage # Define the tools with their metadata tools_definition = [ @@ -42,8 +43,10 @@ {"func": list_regions}, {"func": list_sharedtier_regions}, {"func": run_sql}, + {"func": create_pipeline}, {"func": create_notebook_file}, {"func": upload_notebook_file}, + {"func": upload_file_to_stage}, {"func": create_job_from_notebook}, {"func": get_job}, {"func": delete_job}, diff --git a/tests/integration/tools/test_run_sql_virtual_workspace.py b/tests/integration/tools/test_run_sql_virtual_workspace.py index ef46c97..42cdbf6 100644 --- a/tests/integration/tools/test_run_sql_virtual_workspace.py +++ b/tests/integration/tools/test_run_sql_virtual_workspace.py @@ -13,16 +13,25 @@ def random_name(prefix): @pytest_asyncio.fixture -async def workspace_fixture(): - from mcp.server.fastmcp import Context - - ctx = Context() +async def workspace_fixture(mock_context): + ctx = mock_context workspace_name = random_name("testws") database_name = random_name("testdb") + # Get available regions first + regions = build_request("GET", "regions/sharedtier") + if not regions: + raise ValueError("No shared tier regions available") + + # Use the first available region + first_region = regions[0] + + # Create starter workspace payload = { "name": workspace_name, "databaseName": database_name, + "provider": first_region.get("provider"), + "regionName": first_region.get("regionName"), } starter_workspace_data = build_request( "POST", "sharedtier/virtualWorkspaces", data=payload @@ -31,24 +40,92 @@ async def workspace_fixture(): database_name = starter_workspace_data.get("databaseName") assert workspace_id is not None - yield ctx, workspace_id, database_name + # Create a user for the starter workspace + username = random_name("testuser") + user_payload = { + "userName": username, + } + user_data = build_request( + "POST", f"sharedtier/virtualWorkspaces/{workspace_id}/users", data=user_payload + ) + user_id = user_data.get("userID") + user_password = user_data.get("password") + assert user_id is not None + assert user_password is not None + + yield ctx, workspace_id, database_name, username, user_password - build_request("DELETE", f"sharedtier/virtualWorkspaces/{workspace_id}") + # Cleanup: delete the starter workspace + try: + build_request("DELETE", f"sharedtier/virtualWorkspaces/{workspace_id}") + except Exception: + pass # Ignore cleanup errors @pytest.mark.integration class TestRunSQLVirtualWorkspace: @pytest.mark.asyncio - @pytest.mark.skip(reason="Skipping integration test for now.") async def test_run_sql_on_virtual_workspace(self, workspace_fixture): - ctx, workspace_id, database_name = workspace_fixture - sql_query = "SELECT 1 AS test_col" + """Test creating a table and inserting/selecting data.""" + ctx, workspace_id, 
database_name, username, password = workspace_fixture + + # Create a test table + create_table_sql = """ + CREATE TABLE IF NOT EXISTS test_table ( + id INT PRIMARY KEY, + name VARCHAR(100), + value DECIMAL(10,2) + ) + """ + + result = await run_sql( + ctx=ctx, + sql_query=create_table_sql, + id=workspace_id, + database=database_name, + username=username, + password=password, + ) + assert result["status"] == "success" + + # Insert test data + insert_sql = """ + INSERT INTO test_table (id, name, value) VALUES + (1, 'test_item_1', 10.50), + (2, 'test_item_2', 25.75) + """ + + result = await run_sql( + ctx=ctx, + sql_query=insert_sql, + id=workspace_id, + database=database_name, + username=username, + password=password, + ) + assert result["status"] == "success" + + # Select and verify data + select_sql = "SELECT id, name, value FROM test_table ORDER BY id" + result = await run_sql( ctx=ctx, - sql_query=sql_query, + sql_query=select_sql, id=workspace_id, database=database_name, + username=username, + password=password, ) + assert result["status"] == "success" - assert result["data"]["row_count"] == 1 - assert result["data"]["result"][0]["test_col"] == 1 + assert result["data"]["row_count"] == 2 + + # Verify the data content + rows = result["data"]["result"] + assert rows[0]["id"] == 1 + assert rows[0]["name"] == "test_item_1" + assert float(rows[0]["value"]) == 10.50 + + assert rows[1]["id"] == 2 + assert rows[1]["name"] == "test_item_2" + assert float(rows[1]["value"]) == 25.75 From d8bc75cf28eca1a664a0a4f15b4143b8f3b08be0 Mon Sep 17 00:00:00 2001 From: Pedro Rodrigues Date: Wed, 13 Aug 2025 21:47:43 +0100 Subject: [PATCH 2/2] create pipelines --- changelog/0.4.9.md | 2 - src/api/tools/database/database.py | 158 ++++---- src/api/tools/tools.py | 4 +- src/auth/session_credentials_manager.py | 4 +- tests/integration/tools/test_database.py | 352 ++++++++++++++++++ .../tools/test_run_sql_virtual_workspace.py | 131 ------- tests/unit/test_create_pipeline.py | 197 ++++++++++ 7 files changed, 639 insertions(+), 209 deletions(-) create mode 100644 tests/integration/tools/test_database.py delete mode 100644 tests/integration/tools/test_run_sql_virtual_workspace.py create mode 100644 tests/unit/test_create_pipeline.py diff --git a/changelog/0.4.9.md b/changelog/0.4.9.md index f3c4e64..58752f3 100644 --- a/changelog/0.4.9.md +++ b/changelog/0.4.9.md @@ -4,5 +4,3 @@ - `get_job` tool now supports fetching job details by job ID. - `run_sql` tool now supports optional username and password parameters for database authentication. -- `upload_file_to_stage` tool to upload a file to SingleStore's driver "Stage" -- `create_pipeline` tool to create a pipeline in SingleStore to process data from a stage to a SingleStore database diff --git a/src/api/tools/database/database.py b/src/api/tools/database/database.py index c126fd8..48fd555 100644 --- a/src/api/tools/database/database.py +++ b/src/api/tools/database/database.py @@ -223,7 +223,7 @@ async def __execute_sql_unified( raise -def __get_workspace_by_id(workspace_id: str) -> WorkspaceTarget: +def get_workspace_by_id(workspace_id: str) -> WorkspaceTarget: """ Get a workspace or starter workspace by ID. 
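The hunk that follows restricts create_pipeline to S3 sources and CSV, with MAX_PARTITIONS_PER_BATCH fixed at 2 and auto-start always on. For reference, a sketch of roughly the SQL it would emit for the docstring's uk_price_paid example (names taken from that docstring; indentation approximate; the "process_" heuristic routes the target to INTO PROCEDURE):

    # Roughly the statements create_pipeline generates for the documented example;
    # raw string so the literal \\ escape and \n terminator read as in the emitted SQL.
    expected_create_sql = r"""
    CREATE OR REPLACE PIPELINE uk_price_paid AS
        LOAD DATA S3 's3://singlestore-docs-example-datasets/pp-monthly/pp-monthly-update-new-version.csv'
        CREDENTIALS '{}'
        MAX_PARTITIONS_PER_BATCH 2
        INTO PROCEDURE process_uk_price_paid
        FORMAT CSV
        FIELDS TERMINATED BY ',' ENCLOSED BY '"' ESCAPED BY '\\'
        LINES TERMINATED BY '\n' STARTING BY '';
    """
    expected_start_sql = "START PIPELINE IF NOT RUNNING uk_price_paid;"
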
@@ -290,104 +290,122 @@ async def create_pipeline( target_table_or_procedure: str, workspace_id: str, database: Optional[str] = None, - file_format: str = "CSV", credentials: Optional[str] = None, - config_options: Optional[str] = None, - max_partitions_per_batch: Optional[int] = None, - field_terminator: str = ",", - field_enclosure: str = '"', - field_escape: str = "\\", - line_terminator: str = "\\n", - line_starting: str = "", - auto_start: bool = True, + username: Optional[str] = None, + password: Optional[str] = None, ) -> Dict[str, Any]: """ - Create a SQL pipeline for streaming data from a data source to a database table or stored procedure. + Create a SQL pipeline for streaming CSV data from an S3 data source to a database table or stored procedure. - A pipeline is a SQL instruction that creates a stream of data from a source (like S3, Stage, etc.) - to a target table or procedure in SingleStore. This tool generates and executes the CREATE PIPELINE - SQL statement and optionally starts the pipeline. + This tool is restricted to S3 data sources only and automatically configures the pipeline with: + - CSV format only + - Comma delimiter (,) + - MAX_PARTITIONS_PER_BATCH = 2 + - Auto-start enabled + - No config options Args: pipeline_name: Name for the pipeline (will be used in CREATE PIPELINE statement) - data_source: Source URL (e.g., S3 URL, Stage URL, etc.) + data_source: S3 URL (must start with 's3://') target_table_or_procedure: Target table name or procedure name (use "PROCEDURE procedure_name" for procedures) workspace_id: Workspace ID where the pipeline should be created database: Optional database name to use - file_format: File format (CSV, JSON, etc.) - default is CSV - credentials: Optional credentials string for accessing the data source - config_options: Optional configuration options (e.g., '{"region": "us-east-1"}') - max_partitions_per_batch: Optional max partitions per batch setting - field_terminator: Field terminator for CSV format (default: ",") - field_enclosure: Field enclosure character for CSV (default: '"') - field_escape: Field escape character for CSV (default: "\\") - line_terminator: Line terminator for CSV (default: "\\n") - line_starting: Line starting pattern for CSV (default: "") - auto_start: Whether to automatically start the pipeline after creation (default: True) + credentials: Optional AWS credentials in JSON format for private S3 buckets: + '{"aws_access_key_id": "key", "aws_secret_access_key": "secret", "aws_session_token": "token"}' + Use '{}' for public S3 buckets + username: Optional database username for API key authentication + password: Optional database password for API key authentication Returns: Dictionary with pipeline creation status and details Example: - # Create pipeline from S3 + # Create pipeline from public S3 bucket pipeline_name = "uk_price_paid" data_source = "s3://singlestore-docs-example-datasets/pp-monthly/pp-monthly-update-new-version.csv" - target_table_or_procedure = "PROCEDURE process_uk_price_paid" + target_table_or_procedure = "process_uk_price_paid" credentials = "{}" - config_options = '{"region": "us-east-1"}' - # Create pipeline from Stage - data_source = "stage://datasets/my_data.csv" - target_table_or_procedure = "my_target_table" + # Create pipeline from private S3 bucket + credentials = '{"aws_access_key_id": "AKIAIOSFODNN7EXAMPLE", "aws_secret_access_key": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"}' """ # Validate workspace ID format validated_id = validate_workspace_id(workspace_id) + # Validate that 
data source is S3 + if not data_source.lower().startswith("s3://"): + return { + "status": "error", + "message": "Only S3 data sources are supported. Data source must start with 's3://'", + "errorCode": "INVALID_DATA_SOURCE", + } + await ctx.info( - f"Creating pipeline '{pipeline_name}' from '{data_source}' to '{target_table_or_procedure}' in workspace '{validated_id}'" + f"Creating S3 pipeline '{pipeline_name}' from '{data_source}' to '{target_table_or_procedure}' in workspace '{validated_id}'" ) start_time = time.time() + # Fixed parameters for restrictions + file_format = "CSV" + max_partitions_per_batch = 2 + field_terminator = "," + field_enclosure = '"' + field_escape = "\\\\" + line_terminator = "\\n" + line_starting = "" + # Build the CREATE PIPELINE SQL statement sql_parts = [ f"CREATE OR REPLACE PIPELINE {pipeline_name} AS", - f" LOAD DATA {data_source.upper() if data_source.startswith(('s3://', 'stage://')) else data_source}", + f" LOAD DATA S3 '{data_source}'", ] # Add credentials if provided if credentials: sql_parts.append(f" CREDENTIALS '{credentials}'") - # Add config options if provided - if config_options: - sql_parts.append(f" CONFIG '{config_options}'") + # Add fixed max partitions per batch + sql_parts.append(f" MAX_PARTITIONS_PER_BATCH {max_partitions_per_batch}") - # Add max partitions per batch if provided - if max_partitions_per_batch: - sql_parts.append(f" MAX_PARTITIONS_PER_BATCH {max_partitions_per_batch}") + # Determine if target is a table or procedure and format accordingly + is_procedure = ( + target_table_or_procedure.upper().startswith("PROCEDURE ") + or "process_" in target_table_or_procedure.lower() + ) - # Add target (table or procedure) - sql_parts.append(f" INTO {target_table_or_procedure}") + if is_procedure: + # Remove "PROCEDURE " prefix if present + procedure_name = target_table_or_procedure + if procedure_name.upper().startswith("PROCEDURE "): + procedure_name = procedure_name[10:] # Remove "PROCEDURE " prefix + sql_parts.append(f" INTO PROCEDURE {procedure_name}") + else: + sql_parts.append(f" INTO TABLE {target_table_or_procedure}") - # Add format and CSV-specific options + # Add fixed format and CSV-specific options sql_parts.append(f" FORMAT {file_format.upper()}") - - if file_format.upper() == "CSV": - sql_parts.extend( - [ - f" FIELDS TERMINATED BY '{field_terminator}' ENCLOSED BY '{field_enclosure}' ESCAPED BY '{field_escape}'", - f" LINES TERMINATED BY '{line_terminator}' STARTING BY '{line_starting}'", - ] - ) + sql_parts.append( + f" FIELDS TERMINATED BY '{field_terminator}' ENCLOSED BY '{field_enclosure}' ESCAPED BY '{field_escape}'" + ) + sql_parts.append( + f" LINES TERMINATED BY '{line_terminator}' STARTING BY '{line_starting}'" + ) # Combine all parts into final SQL create_pipeline_sql = "\n".join(sql_parts) + ";" + logger.info(create_pipeline_sql) + try: # Execute the CREATE PIPELINE statement create_result = await run_sql( - ctx=ctx, sql_query=create_pipeline_sql, id=validated_id, database=database + ctx=ctx, + sql_query=create_pipeline_sql, + id=validated_id, + database=database, + username=username, + password=password, ) if create_result.get("status") != "success": @@ -398,16 +416,17 @@ async def create_pipeline( "errorDetails": create_result, } - # If auto_start is True, also start the pipeline + # If auto_start is True, also start the pipeline (always true in restricted mode) start_result = None - if auto_start: - start_pipeline_sql = f"START PIPELINE IF NOT RUNNING {pipeline_name};" - start_result = await run_sql( - 
ctx=ctx, - sql_query=start_pipeline_sql, - id=validated_id, - database=database, - ) + start_pipeline_sql = f"START PIPELINE IF NOT RUNNING {pipeline_name};" + start_result = await run_sql( + ctx=ctx, + sql_query=start_pipeline_sql, + id=validated_id, + database=database, + username=username, + password=password, + ) execution_time = (time.time() - start_time) * 1000 @@ -421,13 +440,13 @@ async def create_pipeline( "name": "create_pipeline", "pipeline_name": pipeline_name, "workspace_id": validated_id, - "auto_start": auto_start, + "auto_start": True, }, ) # Build success message success_message = f"Pipeline '{pipeline_name}' created successfully" - if auto_start and start_result and start_result.get("status") == "success": + if start_result and start_result.get("status") == "success": success_message += " and started" return { @@ -439,20 +458,15 @@ async def create_pipeline( "targetTableOrProcedure": target_table_or_procedure, "workspaceId": validated_id, "database": database, - "autoStarted": auto_start - and start_result - and start_result.get("status") == "success", + "autoStarted": start_result and start_result.get("status") == "success", "createSql": create_pipeline_sql, - "startSql": ( - f"START PIPELINE IF NOT RUNNING {pipeline_name};" - if auto_start - else None - ), + "startSql": f"START PIPELINE IF NOT RUNNING {pipeline_name};", }, "metadata": { "executionTimeMs": round(execution_time, 2), "timestamp": datetime.now().isoformat(), - "fileFormat": file_format, + "fileFormat": "CSV", + "maxPartitionsPerBatch": 2, "creationResult": create_result, "startResult": start_result, }, @@ -510,7 +524,7 @@ async def run_sql( settings = config.get_settings() # Target can either be a workspace or a starter workspace - target = __get_workspace_by_id(validated_id) + target = get_workspace_by_id(validated_id) database_name = database # For starter workspaces, use their database name if not specified diff --git a/src/api/tools/tools.py b/src/api/tools/tools.py index bfcd09b..e2762cf 100644 --- a/src/api/tools/tools.py +++ b/src/api/tools/tools.py @@ -43,10 +43,10 @@ {"func": list_regions}, {"func": list_sharedtier_regions}, {"func": run_sql}, - {"func": create_pipeline}, + {"func": create_pipeline, "internal": True}, {"func": create_notebook_file}, {"func": upload_notebook_file}, - {"func": upload_file_to_stage}, + {"func": upload_file_to_stage, "internal": True}, {"func": create_job_from_notebook}, {"func": get_job}, {"func": delete_job}, diff --git a/src/auth/session_credentials_manager.py b/src/auth/session_credentials_manager.py index 20aa646..fc2081a 100644 --- a/src/auth/session_credentials_manager.py +++ b/src/auth/session_credentials_manager.py @@ -138,7 +138,7 @@ def get_session_credentials_manager() -> SessionCredentialsManager: if _session_credentials_manager is None: _session_credentials_manager = SessionCredentialsManager() - logger.info("Created new session credentials manager") + logger.debug("Created new session credentials manager") return _session_credentials_manager @@ -151,7 +151,7 @@ def reset_session_credentials_manager() -> None: """ global _session_credentials_manager _session_credentials_manager = None - logger.info("Reset session credentials manager") + logger.debug("Reset session credentials manager") def invalidate_credentials(database_name: str) -> None: diff --git a/tests/integration/tools/test_database.py b/tests/integration/tools/test_database.py new file mode 100644 index 0000000..b4ef45e --- /dev/null +++ b/tests/integration/tools/test_database.py @@ -0,0 +1,352 
@@ +import pytest +import pytest_asyncio +import random +import string +from src.api.tools.database.database import run_sql, create_pipeline +from src.api.common import build_request + + +def random_name(prefix): + return f"{prefix}_" + "".join( + random.choices(string.ascii_lowercase + string.digits, k=8) + ) + + +@pytest_asyncio.fixture +async def workspace_fixture(mock_context): + ctx = mock_context + workspace_name = random_name("testws") + database_name = random_name("testdb") + + # Get available regions first + regions = build_request("GET", "regions/sharedtier") + if not regions: + raise ValueError("No shared tier regions available") + + # Use the first available region + first_region = regions[0] + + # Create starter workspace + payload = { + "name": workspace_name, + "databaseName": database_name, + "provider": first_region.get("provider"), + "regionName": first_region.get("regionName"), + } + starter_workspace_data = build_request( + "POST", "sharedtier/virtualWorkspaces", data=payload + ) + workspace_id = starter_workspace_data.get("virtualWorkspaceID") + database_name = starter_workspace_data.get("databaseName") + assert workspace_id is not None + + # Create a user for the starter workspace + username = random_name("testuser") + user_payload = { + "userName": username, + } + user_data = build_request( + "POST", f"sharedtier/virtualWorkspaces/{workspace_id}/users", data=user_payload + ) + user_id = user_data.get("userID") + user_password = user_data.get("password") + assert user_id is not None + assert user_password is not None + + yield ctx, workspace_id, database_name, username, user_password + + # Cleanup: delete the starter workspace + try: + build_request("DELETE", f"sharedtier/virtualWorkspaces/{workspace_id}") + except Exception: + pass # Ignore cleanup errors + + +@pytest.mark.integration +class TestDatabase: + @pytest.mark.asyncio + async def test_run_sql_on_virtual_workspace(self, workspace_fixture): + """Test creating a table and inserting/selecting data.""" + ctx, workspace_id, database_name, username, password = workspace_fixture + + # Create a test table + create_table_sql = """ + CREATE TABLE IF NOT EXISTS test_table ( + id INT PRIMARY KEY, + name VARCHAR(100), + value DECIMAL(10,2) + ) + """ + + result = await run_sql( + ctx=ctx, + sql_query=create_table_sql, + id=workspace_id, + database=database_name, + username=username, + password=password, + ) + assert result["status"] == "success" + + # Insert test data + insert_sql = """ + INSERT INTO test_table (id, name, value) VALUES + (1, 'test_item_1', 10.50), + (2, 'test_item_2', 25.75) + """ + + result = await run_sql( + ctx=ctx, + sql_query=insert_sql, + id=workspace_id, + database=database_name, + username=username, + password=password, + ) + assert result["status"] == "success" + + # Select and verify data + select_sql = "SELECT id, name, value FROM test_table ORDER BY id" + + result = await run_sql( + ctx=ctx, + sql_query=select_sql, + id=workspace_id, + database=database_name, + username=username, + password=password, + ) + + assert result["status"] == "success" + assert result["data"]["row_count"] == 2 + + # Verify the data content + rows = result["data"]["result"] + assert rows[0]["id"] == 1 + assert rows[0]["name"] == "test_item_1" + assert float(rows[0]["value"]) == 10.50 + + assert rows[1]["id"] == 2 + assert rows[1]["name"] == "test_item_2" + assert float(rows[1]["value"]) == 25.75 + + @pytest.mark.asyncio + async def test_create_pipeline_with_csv_data(self, workspace_fixture): + """Test creating a 
pipeline that loads CSV data from S3.""" + ctx, workspace_id, database_name, username, password = workspace_fixture + + # Create the target table with proper schema + create_table_sql = """ + CREATE TABLE IF NOT EXISTS uk_price_paid ( + price BIGINT, + date Date, + postcode VARCHAR(100), + type ENUM('terraced', 'semi-detached', 'detached', 'flat', 'other'), + is_new BOOL, + duration ENUM('freehold', 'leasehold', 'unknown'), + addr1 VARCHAR(100), + addr2 VARCHAR(100), + street VARCHAR(100), + locality VARCHAR(100), + town VARCHAR(100), + district VARCHAR(100), + county VARCHAR(100) + ) + """ + + result = await run_sql( + ctx=ctx, + sql_query=create_table_sql, + id=workspace_id, + database=database_name, + username=username, + password=password, + ) + assert result["status"] == "success" + + # Create the stored procedure to process the CSV data + create_procedure_sql = """ + CREATE OR REPLACE PROCEDURE process_uk_price_paid ( + _batch QUERY( + uuid TEXT NOT NULL, + price TEXT NOT NULL, + date TEXT NOT NULL, + postcode TEXT NOT NULL, + type TEXT NOT NULL, + is_new TEXT NOT NULL, + duration TEXT NOT NULL, + addr1 TEXT NOT NULL, + addr2 TEXT NOT NULL, + street TEXT NOT NULL, + locality TEXT NOT NULL, + town TEXT NOT NULL, + district TEXT NOT NULL, + county TEXT NOT NULL, + val1 TEXT NOT NULL, + val2 TEXT NOT NULL + ) + ) + AS + BEGIN + INSERT INTO uk_price_paid ( + price, + date, + postcode, + type, + is_new, + duration, + addr1, + addr2, + street, + locality, + town, + district, + county + ) + SELECT + price, + date, + postcode, + CASE + WHEN type = 'T' THEN 'terraced' + WHEN type = 'S' THEN 'semi-detached' + WHEN type = 'D' THEN 'detached' + WHEN type = 'F' THEN 'flat' + WHEN type = 'O' THEN 'other' + ELSE 'other' + END AS type, + CASE + WHEN is_new = 'Y' THEN TRUE + ELSE FALSE + END AS is_new, + CASE + WHEN duration = 'F' THEN 'freehold' + WHEN duration = 'L' THEN 'leasehold' + WHEN duration = 'U' THEN 'unknown' + ELSE 'unknown' + END AS duration, + addr1, + addr2, + street, + locality, + town, + district, + county + FROM _batch; + END + """ + + result = await run_sql( + ctx=ctx, + sql_query=create_procedure_sql, + id=workspace_id, + database=database_name, + username=username, + password=password, + ) + assert result["status"] == "success" + + # Create pipeline using the procedure + pipeline_name = random_name("uk_price_paid_pipeline") + data_source = "s3://singlestore-docs-example-datasets/pp-monthly/pp-monthly-update-new-version.csv" + + result = await create_pipeline( + ctx=ctx, + pipeline_name=pipeline_name, + data_source=data_source, + target_table_or_procedure="process_uk_price_paid", + workspace_id=workspace_id, + database=database_name, + credentials="{}", + username=username, + password=password, + ) + + assert result["status"] == "success" + assert result["data"]["pipelineName"] == pipeline_name + assert result["data"]["dataSource"] == data_source + assert result["data"]["targetTableOrProcedure"] == "process_uk_price_paid" + assert result["data"]["autoStarted"] is True + + # Wait a moment for pipeline to process some data + import asyncio + + await asyncio.sleep(5) + + # Check that data was loaded + select_sql = "SELECT COUNT(*) as record_count FROM uk_price_paid" + result = await run_sql( + ctx=ctx, + sql_query=select_sql, + id=workspace_id, + database=database_name, + username=username, + password=password, + ) + + assert result["status"] == "success" + record_count = result["data"]["result"][0]["record_count"] + assert record_count > 0, ( + f"Expected some records to be loaded, 
but got {record_count}" + ) + + # Check sample data structure + sample_sql = "SELECT * FROM uk_price_paid LIMIT 3" + result = await run_sql( + ctx=ctx, + sql_query=sample_sql, + id=workspace_id, + database=database_name, + username=username, + password=password, + ) + + assert result["status"] == "success" + rows = result["data"]["result"] + assert len(rows) > 0 + + # Verify data structure - check that we have expected columns + first_row = rows[0] + expected_columns = [ + "price", + "date", + "postcode", + "type", + "is_new", + ] + for col in expected_columns: + assert col in first_row, f"Expected column '{col}' not found in data" + + # Verify data types and transformations + assert isinstance(first_row["price"], int) + assert first_row["type"] in [ + "terraced", + "semi-detached", + "detached", + "flat", + "other", + ] + + # Stop the pipeline + stop_sql = f"STOP PIPELINE {pipeline_name}" + result = await run_sql( + ctx=ctx, + sql_query=stop_sql, + id=workspace_id, + database=database_name, + username=username, + password=password, + ) + assert result["status"] == "success" + + # Drop the pipeline + drop_sql = f"DROP PIPELINE {pipeline_name}" + result = await run_sql( + ctx=ctx, + sql_query=drop_sql, + id=workspace_id, + database=database_name, + username=username, + password=password, + ) + assert result["status"] == "success" diff --git a/tests/integration/tools/test_run_sql_virtual_workspace.py b/tests/integration/tools/test_run_sql_virtual_workspace.py deleted file mode 100644 index 42cdbf6..0000000 --- a/tests/integration/tools/test_run_sql_virtual_workspace.py +++ /dev/null @@ -1,131 +0,0 @@ -import pytest -import pytest_asyncio -import random -import string -from src.api.tools.database.database import run_sql -from src.api.common import build_request - - -def random_name(prefix): - return f"{prefix}_" + "".join( - random.choices(string.ascii_lowercase + string.digits, k=8) - ) - - -@pytest_asyncio.fixture -async def workspace_fixture(mock_context): - ctx = mock_context - workspace_name = random_name("testws") - database_name = random_name("testdb") - - # Get available regions first - regions = build_request("GET", "regions/sharedtier") - if not regions: - raise ValueError("No shared tier regions available") - - # Use the first available region - first_region = regions[0] - - # Create starter workspace - payload = { - "name": workspace_name, - "databaseName": database_name, - "provider": first_region.get("provider"), - "regionName": first_region.get("regionName"), - } - starter_workspace_data = build_request( - "POST", "sharedtier/virtualWorkspaces", data=payload - ) - workspace_id = starter_workspace_data.get("virtualWorkspaceID") - database_name = starter_workspace_data.get("databaseName") - assert workspace_id is not None - - # Create a user for the starter workspace - username = random_name("testuser") - user_payload = { - "userName": username, - } - user_data = build_request( - "POST", f"sharedtier/virtualWorkspaces/{workspace_id}/users", data=user_payload - ) - user_id = user_data.get("userID") - user_password = user_data.get("password") - assert user_id is not None - assert user_password is not None - - yield ctx, workspace_id, database_name, username, user_password - - # Cleanup: delete the starter workspace - try: - build_request("DELETE", f"sharedtier/virtualWorkspaces/{workspace_id}") - except Exception: - pass # Ignore cleanup errors - - -@pytest.mark.integration -class TestRunSQLVirtualWorkspace: - @pytest.mark.asyncio - async def test_run_sql_on_virtual_workspace(self, 
workspace_fixture):
-        """Test creating a table and inserting/selecting data."""
-        ctx, workspace_id, database_name, username, password = workspace_fixture
-
-        # Create a test table
-        create_table_sql = """
-        CREATE TABLE IF NOT EXISTS test_table (
-            id INT PRIMARY KEY,
-            name VARCHAR(100),
-            value DECIMAL(10,2)
-        )
-        """
-
-        result = await run_sql(
-            ctx=ctx,
-            sql_query=create_table_sql,
-            id=workspace_id,
-            database=database_name,
-            username=username,
-            password=password,
-        )
-        assert result["status"] == "success"
-
-        # Insert test data
-        insert_sql = """
-        INSERT INTO test_table (id, name, value) VALUES
-        (1, 'test_item_1', 10.50),
-        (2, 'test_item_2', 25.75)
-        """
-
-        result = await run_sql(
-            ctx=ctx,
-            sql_query=insert_sql,
-            id=workspace_id,
-            database=database_name,
-            username=username,
-            password=password,
-        )
-        assert result["status"] == "success"
-
-        # Select and verify data
-        select_sql = "SELECT id, name, value FROM test_table ORDER BY id"
-
-        result = await run_sql(
-            ctx=ctx,
-            sql_query=select_sql,
-            id=workspace_id,
-            database=database_name,
-            username=username,
-            password=password,
-        )
-
-        assert result["status"] == "success"
-        assert result["data"]["row_count"] == 2
-
-        # Verify the data content
-        rows = result["data"]["result"]
-        assert rows[0]["id"] == 1
-        assert rows[0]["name"] == "test_item_1"
-        assert float(rows[0]["value"]) == 10.50
-
-        assert rows[1]["id"] == 2
-        assert rows[1]["name"] == "test_item_2"
-        assert float(rows[1]["value"]) == 25.75
diff --git a/tests/unit/test_create_pipeline.py b/tests/unit/test_create_pipeline.py
new file mode 100644
index 0000000..73e14be
--- /dev/null
+++ b/tests/unit/test_create_pipeline.py
@@ -0,0 +1,197 @@
+"""Unit tests for the create_pipeline function."""
+
+import pytest
+from unittest.mock import AsyncMock, MagicMock, patch
+from src.api.tools.database.database import create_pipeline
+
+
+@pytest.mark.asyncio
+async def test_create_pipeline_sql_generation():
+    """Test that create_pipeline generates the expected S3/CSV SQL statements."""
+    # Mock context
+    ctx = AsyncMock()
+    ctx.info = AsyncMock()
+
+    # Mock run_sql to return success responses
+    mock_run_sql_success = AsyncMock(return_value={"status": "success"})
+
+    with patch("src.api.tools.database.database.run_sql", mock_run_sql_success):
+        with patch("src.api.tools.database.database.config") as mock_config:
+            # Mock settings and user_id
+            mock_settings = MagicMock()
+            mock_settings.analytics_manager.track_event = MagicMock()
+            mock_config.get_settings.return_value = mock_settings
+            mock_config.get_user_id.return_value = "test-user"
+
+            # Test basic pipeline creation (S3 + CSV is the only supported mode)
+            result = await create_pipeline(
+                ctx=ctx,
+                pipeline_name="test_pipeline",
+                data_source="s3://test-bucket/data.csv",
+                target_table_or_procedure="test_table",
+                workspace_id="ws-12345678-1234-5678-1234-567812345678",
+                database="test_db",
+                credentials="{}",
+            )
+
+            # Verify the result
+            assert result["status"] == "success"
+            assert result["data"]["pipelineName"] == "test_pipeline"
+            assert result["data"]["dataSource"] == "s3://test-bucket/data.csv"
+            assert result["data"]["targetTableOrProcedure"] == "test_table"
+            assert result["data"]["autoStarted"] is True
+
+            # Check that run_sql was called twice (CREATE and START)
+            assert mock_run_sql_success.call_count == 2
+
+            # Check the first call (CREATE PIPELINE)
+            create_call = mock_run_sql_success.call_args_list[0]
+            create_sql = create_call[1]["sql_query"]
+
+            # Verify CREATE PIPELINE SQL contains expected elements
+            assert "CREATE OR REPLACE PIPELINE test_pipeline AS" in create_sql
+            assert "LOAD DATA S3 's3://test-bucket/data.csv'" in create_sql
+            assert "CREDENTIALS '{}'" in create_sql
+            assert "MAX_PARTITIONS_PER_BATCH 2" in create_sql
+            assert "INTO TABLE test_table" in create_sql
+            assert "FORMAT CSV" in create_sql
+            assert "FIELDS TERMINATED BY ','" in create_sql
+
+            # Check the second call (START PIPELINE)
+            start_call = mock_run_sql_success.call_args_list[1]
+            start_sql = start_call[1]["sql_query"]
+            assert start_sql == "START PIPELINE IF NOT RUNNING test_pipeline;"
+
+
+@pytest.mark.asyncio
+async def test_create_pipeline_with_procedure_target():
+    """Test create_pipeline with a stored procedure target."""
+    ctx = AsyncMock()
+    ctx.info = AsyncMock()
+
+    mock_run_sql_success = AsyncMock(return_value={"status": "success"})
+
+    with patch("src.api.tools.database.database.run_sql", mock_run_sql_success):
+        with patch("src.api.tools.database.database.config") as mock_config:
+            mock_settings = MagicMock()
+            mock_settings.analytics_manager.track_event = MagicMock()
+            mock_config.get_settings.return_value = mock_settings
+            mock_config.get_user_id.return_value = "test-user"
+
+            result = await create_pipeline(
+                ctx=ctx,
+                pipeline_name="proc_pipeline",
+                data_source="s3://test-bucket/data.csv",
+                target_table_or_procedure="PROCEDURE process_uk_price_paid",
+                workspace_id="ws-12345678-1234-5678-1234-567812345678",
+            )
+
+            assert result["status"] == "success"
+            assert result["data"]["autoStarted"] is True
+
+            # Check the CREATE PIPELINE SQL routes data into the procedure
+            create_call = mock_run_sql_success.call_args_list[0]
+            create_sql = create_call[1]["sql_query"]
+
+            assert "INTO PROCEDURE process_uk_price_paid" in create_sql
+            assert "INTO TABLE" not in create_sql
+
+
+@pytest.mark.asyncio
+async def test_create_pipeline_with_credentials():
+    """Test create_pipeline passes username and password correctly."""
+    ctx = AsyncMock()
+    ctx.info = AsyncMock()
+
+    mock_run_sql_success = AsyncMock(return_value={"status": "success"})
+
+    with patch("src.api.tools.database.database.run_sql", mock_run_sql_success):
+        with patch("src.api.tools.database.database.config") as mock_config:
+            mock_settings = MagicMock()
+            mock_settings.analytics_manager.track_event = MagicMock()
+            mock_config.get_settings.return_value = mock_settings
+            mock_config.get_user_id.return_value = "test-user"
+
+            # Test with username and password
+            result = await create_pipeline(
+                ctx=ctx,
+                pipeline_name="test_pipeline",
+                data_source="s3://test-bucket/data.csv",
+                target_table_or_procedure="test_table",
+                workspace_id="ws-12345678-1234-5678-1234-567812345678",
+                username="test_user",
+                password="test_pass",
+            )
+
+            assert result["status"] == "success"
+
+            # Verify that both run_sql calls received username and password
+            for call in mock_run_sql_success.call_args_list:
+                assert call[1]["username"] == "test_user"
+                assert call[1]["password"] == "test_pass"
+
+
+@pytest.mark.asyncio
+async def test_create_pipeline_error_handling():
+    """Test create_pipeline handles errors properly."""
+    ctx = AsyncMock()
+    ctx.info = AsyncMock()
+
+    # Mock run_sql to return an error
+    mock_run_sql_error = AsyncMock(
+        return_value={"status": "error", "message": "SQL execution failed"}
+    )
+
+    with patch("src.api.tools.database.database.run_sql", mock_run_sql_error):
+        result = await create_pipeline(
+            ctx=ctx,
+            pipeline_name="error_pipeline",
+            data_source="s3://test-bucket/data.csv",
+            target_table_or_procedure="test_table",
+            workspace_id="ws-12345678-1234-5678-1234-567812345678",
+        )
+
+        assert result["status"] == "error"
+        assert "Failed to create pipeline" in result["message"]
+        assert result["errorCode"] == "PIPELINE_CREATION_FAILED"
+
+
+@pytest.mark.asyncio
+async def test_create_pipeline_rejects_non_s3_source():
+    """Test create_pipeline rejects non-S3 data sources such as Stage URLs."""
+    ctx = AsyncMock()
+    ctx.info = AsyncMock()
+
+    mock_run_sql = AsyncMock(return_value={"status": "success"})
+
+    with patch("src.api.tools.database.database.run_sql", mock_run_sql):
+        result = await create_pipeline(
+            ctx=ctx,
+            pipeline_name="stage_pipeline",
+            data_source="stage://my_data/file.csv",
+            target_table_or_procedure="stage_table",
+            workspace_id="ws-12345678-1234-5678-1234-567812345678",
+        )
+
+        assert result["status"] == "error"
+        assert result["errorCode"] == "INVALID_DATA_SOURCE"
+
+        # No SQL should be executed when the data source is rejected
+        mock_run_sql.assert_not_called()
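
Finally, a minimal usage sketch for the upload_file_to_stage tool added in the first commit, assuming API-key authentication; the local path and workspace group ID below are placeholders, and the returned stageUrl is the value the Stage docstring says can be referenced from later SQL operations:

    import asyncio
    from mcp.server.fastmcp import Context
    from src.api.tools.stage import upload_file_to_stage

    async def main() -> None:
        ctx = Context()  # normally provided by the MCP server; shown only to keep the sketch self-contained
        result = await upload_file_to_stage(
            ctx=ctx,
            local_path="/tmp/sample.csv",         # placeholder local file, must be under 1GB
            workspace_group_id="wg-123abc456def",  # placeholder workspace group ID
        )
        if result["status"] == "success":
            # Stage URL where the file now lives, as returned by the tool
            print(result["data"]["stageUrl"])
        else:
            print(result["errorCode"], result["message"])

    asyncio.run(main())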