Skip to content

Commit

Permalink
Only modify the relative URL if a bucket URL is specified
Browse files Browse the repository at this point in the history
  • Loading branch information
stevenayers committed Feb 27, 2025
1 parent 0b8ad87 commit 70fed71
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 25 deletions.
28 changes: 16 additions & 12 deletions dlt/destinations/impl/snowflake/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Literal, Optional, Dict, Any, Union, Tuple
from typing import Optional, Dict, Any
from urllib.parse import urlparse, urlunparse
from enum import Enum

from dlt.common.configuration.specs import (
CredentialsConfiguration,
Expand All @@ -13,7 +12,6 @@
from dlt.common.storages.fsspec_filesystem import (
AZURE_BLOB_STORAGE_PROTOCOLS,
S3_PROTOCOLS,
GCS_PROTOCOLS,
)
from dlt.common.typing import TLoaderFileFormat
from dlt.destinations.exceptions import LoadJobTerminalException
Expand Down Expand Up @@ -100,6 +98,7 @@ def gen_copy_sql(
loader_file_format: Format of the source file (jsonl, parquet, csv)
is_case_sensitive: Whether column matching should be case-sensitive
stage_name: Optional name of a predefined Snowflake stage
stage_bucket_url: Optional URL of the bucket containing the file
local_stage_file_path: Path to use for local files
staging_credentials: Optional credentials for accessing cloud storage
csv_format: Optional configuration for CSV format
Expand All @@ -108,7 +107,7 @@ def gen_copy_sql(
A SQL string containing the COPY command
"""
# Determine file location type and get potentially modified URL
scheme = urlparse(file_url).scheme
parsed_file_url = urlparse(file_url)

# Initialize clause components
from_clause = ""
Expand All @@ -117,19 +116,24 @@ def gen_copy_sql(
on_error_clause = ""

# Handle different file location types
if scheme == "file" or FilesystemConfiguration.is_local_path(file_url):
if parsed_file_url.scheme == "file" or FilesystemConfiguration.is_local_path(file_url):
from_clause = f"FROM {local_stage_file_path}"

elif stage_name:
if not stage_bucket_url:
raise LoadJobTerminalException(
file_url, f"Cannot load from stage {stage_name} without a stage bucket URL."
)
relative_url = file_url.removeprefix(stage_bucket_url)
relative_url = parsed_file_url.path.lstrip("/")

# If stage bucket URL has a path, remove it from the beginning of the relative URL because this is already specified
# in the Snowflake stage location.
if stage_bucket_url:
parsed_bucket_url = urlparse(stage_bucket_url)
stage_bucket_path = parsed_bucket_url.path.lstrip("/")
if stage_bucket_path:
relative_url = relative_url.removeprefix(stage_bucket_path)

from_clause = f"FROM @{stage_name}"
files_clause = f"FILES = ('{relative_url}')"

elif scheme in S3_PROTOCOLS:
elif parsed_file_url.scheme in S3_PROTOCOLS:
if staging_credentials and isinstance(staging_credentials, AwsCredentialsWithoutDefaults):
credentials_clause = (
f"CREDENTIALS=(AWS_KEY_ID='{staging_credentials.aws_access_key_id}' "
Expand All @@ -142,7 +146,7 @@ def gen_copy_sql(
f"Cannot load from S3 path {file_url} without either credentials or a stage name.",
)

elif scheme in AZURE_BLOB_STORAGE_PROTOCOLS:
elif parsed_file_url.scheme in AZURE_BLOB_STORAGE_PROTOCOLS:
if staging_credentials and isinstance(staging_credentials, AzureCredentialsWithoutDefaults):
credentials_clause = (
f"CREDENTIALS=(AZURE_SAS_TOKEN='?{staging_credentials.azure_storage_sas_token}')"
Expand Down
6 changes: 6 additions & 0 deletions tests/load/snowflake/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ def stage_bucket_url():
return "s3://test-stage-bucket/"


@pytest.fixture
def stage_bucket_url_with_prefix():
    """Fixture for a stage bucket URL that includes a path prefix (no trailing slash)."""
    # The prefix ("with/prefix") lets tests verify that the bucket path is
    # stripped from the relative FILES path in the generated COPY command.
    return "s3://test-stage-bucket/with/prefix"


@pytest.fixture
def local_file_path():
"""Fixture for local file path."""
Expand Down
60 changes: 47 additions & 13 deletions tests/load/snowflake/test_snowflake_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,21 +134,55 @@ def test_gen_copy_sql_with_stage(test_table, stage_name, stage_bucket_url):
)


def test_gen_copy_sql_with_stage_with_prefix_no_slash(
    test_table, stage_name, stage_bucket_url_with_prefix
):
    """Test generating COPY command with a named stage and a bucket URL without a trailing slash.

    The bucket path prefix ("with/prefix") must be removed from the FILES
    entry because the Snowflake stage location already points at that prefix.
    """
    file_url = f"{stage_bucket_url_with_prefix}path/to/file.parquet"

    sql = gen_copy_sql(
        file_url=file_url,
        qualified_table_name=test_table,
        loader_file_format="parquet",
        is_case_sensitive=False,
        stage_name=stage_name,
        stage_bucket_url=stage_bucket_url_with_prefix,
    )

    # The relative path in FILES must not repeat the bucket's "with/prefix" part.
    assert_sql_contains(
        sql,
        f"COPY INTO {test_table}",
        f"FROM @{stage_name}",
        "FILES = ('path/to/file.parquet')",
        "FILE_FORMAT = (TYPE = 'PARQUET'",
        "MATCH_BY_COLUMN_NAME='CASE_INSENSITIVE'",
    )


def test_gen_copy_sql_with_stage_with_prefix_slash(
    test_table, stage_name, stage_bucket_url_with_prefix
):
    """Test generating COPY command with a named stage and a bucket URL with a trailing slash.

    Same expectation as the no-slash variant: the bucket's path prefix must be
    stripped from the FILES entry regardless of whether the configured bucket
    URL ends with "/".
    """
    stage_bucket_url_with_prefix_slash = f"{stage_bucket_url_with_prefix}/"
    file_url = f"{stage_bucket_url_with_prefix_slash}path/to/file.parquet"

    sql = gen_copy_sql(
        file_url=file_url,
        qualified_table_name=test_table,
        loader_file_format="parquet",
        is_case_sensitive=False,
        stage_name=stage_name,
        stage_bucket_url=stage_bucket_url_with_prefix_slash,
    )

    # The relative path in FILES must not repeat the bucket's "with/prefix" part.
    assert_sql_contains(
        sql,
        f"COPY INTO {test_table}",
        f"FROM @{stage_name}",
        "FILES = ('path/to/file.parquet')",
        "FILE_FORMAT = (TYPE = 'PARQUET'",
        "MATCH_BY_COLUMN_NAME='CASE_INSENSITIVE'",
    )


def test_gen_copy_sql_s3_with_credentials(test_table, aws_credentials):
Expand Down

0 comments on commit 70fed71

Please sign in to comment.