diff --git a/dlt/destinations/impl/snowflake/utils.py b/dlt/destinations/impl/snowflake/utils.py
index 208bb4f27d..bcbe9c7b3f 100644
--- a/dlt/destinations/impl/snowflake/utils.py
+++ b/dlt/destinations/impl/snowflake/utils.py
@@ -1,6 +1,5 @@
-from typing import Literal, Optional, Dict, Any, Union, Tuple
+from typing import Optional, Dict, Any
 from urllib.parse import urlparse, urlunparse
-from enum import Enum
 
 from dlt.common.configuration.specs import (
     CredentialsConfiguration,
@@ -13,7 +12,6 @@
 from dlt.common.storages.fsspec_filesystem import (
     AZURE_BLOB_STORAGE_PROTOCOLS,
     S3_PROTOCOLS,
-    GCS_PROTOCOLS,
 )
 from dlt.common.typing import TLoaderFileFormat
 from dlt.destinations.exceptions import LoadJobTerminalException
@@ -100,6 +98,7 @@ def gen_copy_sql(
         loader_file_format: Format of the source file (jsonl, parquet, csv)
         is_case_sensitive: Whether column matching should be case-sensitive
         stage_name: Optional name of a predefined Snowflake stage
+        stage_bucket_url: Optional URL of the bucket containing the file
         local_stage_file_path: Path to use for local files
         staging_credentials: Optional credentials for accessing cloud storage
         csv_format: Optional configuration for CSV format
@@ -108,7 +107,7 @@
         A SQL string containing the COPY command
     """
     # Determine file location type and get potentially modified URL
-    scheme = urlparse(file_url).scheme
+    parsed_file_url = urlparse(file_url)
 
     # Initialize clause components
     from_clause = ""
@@ -117,19 +116,27 @@
     on_error_clause = ""
 
     # Handle different file location types
-    if scheme == "file" or FilesystemConfiguration.is_local_path(file_url):
+    if parsed_file_url.scheme == "file" or FilesystemConfiguration.is_local_path(file_url):
         from_clause = f"FROM {local_stage_file_path}"
 
     elif stage_name:
-        if not stage_bucket_url:
-            raise LoadJobTerminalException(
-                file_url, f"Cannot load from stage {stage_name} without a stage bucket URL."
-            )
-        relative_url = file_url.removeprefix(stage_bucket_url)
+        relative_url = parsed_file_url.path.lstrip("/")
+
+        # If the stage bucket URL has a path, strip it from the beginning of the
+        # relative URL, since that prefix is already part of the Snowflake stage location.
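+        # For example, stage_bucket_url "s3://test-stage-bucket/with/prefix/" and
+        # file_url "s3://test-stage-bucket/with/prefix/path/to/file.parquet"
+        # yield relative_url "path/to/file.parquet" (see the test fixtures below).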
+        if stage_bucket_url:
+            parsed_bucket_url = urlparse(stage_bucket_url)
+            stage_bucket_path = parsed_bucket_url.path.lstrip("/")
+            if stage_bucket_path:
+                relative_url = relative_url.removeprefix(stage_bucket_path)
+
         from_clause = f"FROM @{stage_name}"
         files_clause = f"FILES = ('{relative_url}')"
 
-    elif scheme in S3_PROTOCOLS:
+    elif parsed_file_url.scheme in S3_PROTOCOLS:
         if staging_credentials and isinstance(staging_credentials, AwsCredentialsWithoutDefaults):
             credentials_clause = (
                 f"CREDENTIALS=(AWS_KEY_ID='{staging_credentials.aws_access_key_id}' "
@@ -142,7 +149,7 @@
                 f"Cannot load from S3 path {file_url} without either credentials or a stage name.",
             )
 
-    elif scheme in AZURE_BLOB_STORAGE_PROTOCOLS:
+    elif parsed_file_url.scheme in AZURE_BLOB_STORAGE_PROTOCOLS:
         if staging_credentials and isinstance(staging_credentials, AzureCredentialsWithoutDefaults):
             credentials_clause = (
                 f"CREDENTIALS=(AZURE_SAS_TOKEN='?{staging_credentials.azure_storage_sas_token}')"
diff --git a/tests/load/snowflake/conftest.py b/tests/load/snowflake/conftest.py
index fa524b4c52..19e3faa1a9 100644
--- a/tests/load/snowflake/conftest.py
+++ b/tests/load/snowflake/conftest.py
@@ -19,6 +19,12 @@ def stage_bucket_url():
     return "s3://test-stage-bucket/"
 
 
+@pytest.fixture
+def stage_bucket_url_with_prefix():
+    """Fixture for stage bucket URL with a path prefix."""
+    return "s3://test-stage-bucket/with/prefix"
+
+
 @pytest.fixture
 def local_file_path():
     """Fixture for local file path."""
diff --git a/tests/load/snowflake/test_snowflake_utils.py b/tests/load/snowflake/test_snowflake_utils.py
index a0484d0817..e82461a7b8 100644
--- a/tests/load/snowflake/test_snowflake_utils.py
+++ b/tests/load/snowflake/test_snowflake_utils.py
@@ -134,21 +134,55 @@ def test_gen_copy_sql_with_stage(test_table, stage_name, stage_bucket_url):
     )
 
 
-def test_gen_copy_sql_stage_without_bucket_url(test_table, stage_name):
-    """Test that using a stage without bucket URL raises an error."""
-    file_url = "s3://bucket/path/to/file.parquet"
+def test_gen_copy_sql_with_stage_with_prefix_no_slash(
+    test_table, stage_name, stage_bucket_url_with_prefix
+):
+    """Test generating COPY command with a named stage and a bucket URL without a trailing slash."""
+    file_url = f"{stage_bucket_url_with_prefix}path/to/file.parquet"
 
-    with pytest.raises(LoadJobTerminalException) as excinfo:
-        gen_copy_sql(
-            file_url=file_url,
-            qualified_table_name=test_table,
-            loader_file_format="parquet",
-            is_case_sensitive=True,
-            stage_name=stage_name,
-        )
+    sql = gen_copy_sql(
+        file_url=file_url,
+        qualified_table_name=test_table,
+        loader_file_format="parquet",
+        is_case_sensitive=False,
+        stage_name=stage_name,
+        stage_bucket_url=stage_bucket_url_with_prefix,
+    )
 
-    assert "Cannot load from stage" in str(excinfo.value)
-    assert "without a stage bucket URL" in str(excinfo.value)
+    assert_sql_contains(
+        sql,
+        f"COPY INTO {test_table}",
+        f"FROM @{stage_name}",
+        "FILES = ('path/to/file.parquet')",
+        "FILE_FORMAT = (TYPE = 'PARQUET'",
+        "MATCH_BY_COLUMN_NAME='CASE_INSENSITIVE'",
+    )
+
+
+def test_gen_copy_sql_with_stage_with_prefix_slash(
+    test_table, stage_name, stage_bucket_url_with_prefix
+):
+    """Test generating COPY command with a named stage and a bucket URL with a trailing slash."""
+    stage_bucket_url_with_prefix_slash = f"{stage_bucket_url_with_prefix}/"
+    file_url = f"{stage_bucket_url_with_prefix_slash}path/to/file.parquet"
+
+    sql = gen_copy_sql(
+        file_url=file_url,
+        qualified_table_name=test_table,
+        loader_file_format="parquet",
+        is_case_sensitive=False,
+        stage_name=stage_name,
+        stage_bucket_url=stage_bucket_url_with_prefix_slash,
+    )
+
+    assert_sql_contains(
+        sql,
+        f"COPY INTO {test_table}",
+        f"FROM @{stage_name}",
+        "FILES = ('path/to/file.parquet')",
+        "FILE_FORMAT = (TYPE = 'PARQUET'",
+        "MATCH_BY_COLUMN_NAME='CASE_INSENSITIVE'",
+    )
 
 
 def test_gen_copy_sql_s3_with_credentials(test_table, aws_credentials):
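
For reference, a minimal standalone sketch of the relative-path derivation introduced in gen_copy_sql, using only the standard library. The helper name relative_stage_path is hypothetical (not part of the patch); the example URLs come from the test fixtures above.

    from typing import Optional
    from urllib.parse import urlparse


    def relative_stage_path(file_url: str, stage_bucket_url: Optional[str]) -> str:
        # Take the path portion of the file URL and drop any leading slashes.
        relative_url = urlparse(file_url).path.lstrip("/")
        # Strip the stage bucket's own path prefix, which the Snowflake
        # stage location already covers.
        if stage_bucket_url:
            stage_bucket_path = urlparse(stage_bucket_url).path.lstrip("/")
            if stage_bucket_path:
                relative_url = relative_url.removeprefix(stage_bucket_path)
        return relative_url


    # Bucket URLs with and without a trailing slash resolve to the same
    # stage-relative path (str.removeprefix requires Python >= 3.9):
    assert relative_stage_path(
        "s3://test-stage-bucket/with/prefix/path/to/file.parquet",
        "s3://test-stage-bucket/with/prefix/",
    ) == "path/to/file.parquet"
    assert relative_stage_path(
        "s3://test-stage-bucket/with/prefixpath/to/file.parquet",
        "s3://test-stage-bucket/with/prefix",
    ) == "path/to/file.parquet"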