Commit 0b8ad87

stevenayers committed Feb 26, 2025
1 parent 5a60808 commit 0b8ad87

Showing 5 changed files with 678 additions and 173 deletions.
162 changes: 16 additions & 146 deletions dlt/destinations/impl/snowflake/snowflake.py
@@ -1,8 +1,6 @@
from typing import Optional, Sequence, List, Dict, Set
from urllib.parse import urlparse, urlunparse
from typing import Optional, Sequence, List, Dict

from dlt.common import logger
from dlt.common.data_writers.configuration import CsvFormatConfiguration
from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.destination.client import (
HasFollowupJobs,
@@ -12,21 +10,14 @@
CredentialsConfiguration,
SupportsStagingDestination,
)
from dlt.common.configuration.specs import (
AwsCredentialsWithoutDefaults,
AzureCredentialsWithoutDefaults,
)
from dlt.common.schema.utils import get_columns_names_with_prop
from dlt.common.storages.configuration import FilesystemConfiguration, ensure_canonical_az_url
from dlt.common.storages.file_storage import FileStorage
from dlt.common.schema import TColumnSchema, Schema, TColumnHint
from dlt.common.schema.typing import TColumnType, TTableSchema

from dlt.common.storages.fsspec_filesystem import AZURE_BLOB_STORAGE_PROTOCOLS, S3_PROTOCOLS
from dlt.common.typing import TLoaderFileFormat
from dlt.common.utils import uniq_id
from dlt.destinations.impl.snowflake.utils import gen_copy_sql
from dlt.destinations.job_client_impl import SqlJobClientWithStagingDataset
from dlt.destinations.exceptions import LoadJobTerminalException

from dlt.destinations.impl.snowflake.configuration import SnowflakeClientConfiguration
from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient
@@ -76,15 +67,20 @@ def run(self) -> None:
)
stage_file_path = f'@{self._stage_name}/"{self._load_id}"/{file_name}'

copy_sql = self.gen_copy_sql(
file_url,
qualified_table_name,
file_format, # type: ignore[arg-type]
self._sql_client.capabilities.generates_case_sensitive_identifiers(),
self._stage_name,
stage_file_path,
self._staging_credentials,
self._config.csv_format,
stage_bucket_url = None
if self._config.staging_config and self._config.staging_config.bucket_url:
stage_bucket_url = self._config.staging_config.bucket_url

copy_sql = gen_copy_sql(
file_url=file_url,
qualified_table_name=qualified_table_name,
loader_file_format=file_format, # type: ignore[arg-type]
is_case_sensitive=self._sql_client.capabilities.generates_case_sensitive_identifiers(),
stage_name=self._stage_name,
stage_bucket_url=stage_bucket_url,
local_stage_file_path=stage_file_path,
staging_credentials=self._staging_credentials,
csv_format=self._config.csv_format,
)

with self._sql_client.begin_transaction():
@@ -98,132 +94,6 @@ def run(self) -> None:
if stage_file_path and not self._keep_staged_files:
self._sql_client.execute_sql(f"REMOVE {stage_file_path}")

@classmethod
def gen_copy_sql(
cls,
file_url: str,
qualified_table_name: str,
loader_file_format: TLoaderFileFormat,
is_case_sensitive: bool,
stage_name: Optional[str] = None,
local_stage_file_path: Optional[str] = None,
staging_credentials: Optional[CredentialsConfiguration] = None,
csv_format: Optional[CsvFormatConfiguration] = None,
) -> str:
parsed_file_url = urlparse(file_url)
# check if local filesystem (file scheme or just a local file in native form)
is_local = parsed_file_url.scheme == "file" or FilesystemConfiguration.is_local_path(
file_url
)
# file_name = FileStorage.get_file_name_from_file_path(file_url)

from_clause = ""
credentials_clause = ""
files_clause = ""
on_error_clause = ""

case_folding = "CASE_SENSITIVE" if is_case_sensitive else "CASE_INSENSITIVE"
column_match_clause = f"MATCH_BY_COLUMN_NAME='{case_folding}'"

if not is_local:
bucket_scheme = parsed_file_url.scheme
# referencing an external s3/azure stage does not require explicit AWS credentials
if bucket_scheme in AZURE_BLOB_STORAGE_PROTOCOLS + S3_PROTOCOLS and stage_name:
from_clause = f"FROM '@{stage_name}'"
files_clause = f"FILES = ('{parsed_file_url.path.lstrip('/')}')"
# referencing staged files via a bucket URL requires explicit AWS credentials
elif (
bucket_scheme in S3_PROTOCOLS
and staging_credentials
and isinstance(staging_credentials, AwsCredentialsWithoutDefaults)
):
credentials_clause = f"""CREDENTIALS=(AWS_KEY_ID='{staging_credentials.aws_access_key_id}' AWS_SECRET_KEY='{staging_credentials.aws_secret_access_key}')"""
from_clause = f"FROM '{file_url}'"
elif (
bucket_scheme in AZURE_BLOB_STORAGE_PROTOCOLS
and staging_credentials
and isinstance(staging_credentials, AzureCredentialsWithoutDefaults)
):
credentials_clause = f"CREDENTIALS=(AZURE_SAS_TOKEN='?{staging_credentials.azure_storage_sas_token}')"
file_url = cls.ensure_snowflake_azure_url(
file_url,
staging_credentials.azure_storage_account_name,
staging_credentials.azure_account_host,
)
from_clause = f"FROM '{file_url}'"
else:
# ensure that gcs bucket path starts with gcs://, this is a requirement of snowflake
file_url = file_url.replace("gs://", "gcs://")
if not stage_name:
# when loading from bucket stage must be given
raise LoadJobTerminalException(
file_url,
f"Cannot load from bucket path {file_url} without a stage name. See"
" https://dlthub.com/docs/dlt-ecosystem/destinations/snowflake for"
" instructions on setting up the `stage_name`",
)
from_clause = f"FROM @{stage_name}/"
files_clause = f"FILES = ('{urlparse(file_url).path.lstrip('/')}')"
else:
from_clause = f"FROM {local_stage_file_path}"

# decide on source format, stage_file_path will either be a local file or a bucket path
if loader_file_format == "jsonl":
source_format = "( TYPE = 'JSON', BINARY_FORMAT = 'BASE64' )"
elif loader_file_format == "parquet":
source_format = (
"(TYPE = 'PARQUET', BINARY_AS_TEXT = FALSE, USE_LOGICAL_TYPE = TRUE)"
# TODO: USE_VECTORIZED_SCANNER inserts null strings into VARIANT JSON
# " USE_VECTORIZED_SCANNER = TRUE)"
)
elif loader_file_format == "csv":
# empty strings are NULL, no data is NULL, missing columns (ERROR_ON_COLUMN_COUNT_MISMATCH) are NULL
csv_format = csv_format or CsvFormatConfiguration()
source_format = (
"(TYPE = 'CSV', BINARY_FORMAT = 'UTF-8', PARSE_HEADER ="
f" {csv_format.include_header}, FIELD_OPTIONALLY_ENCLOSED_BY = '\"', NULL_IF ="
" (''), ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE,"
f" FIELD_DELIMITER='{csv_format.delimiter}', ENCODING='{csv_format.encoding}')"
)
# disable column match if headers are not provided
if not csv_format.include_header:
column_match_clause = ""
if csv_format.on_error_continue:
on_error_clause = "ON_ERROR = CONTINUE"
else:
raise ValueError(f"{loader_file_format} not supported for Snowflake COPY command.")

return f"""COPY INTO {qualified_table_name}
{from_clause}
{files_clause}
{credentials_clause}
FILE_FORMAT = {source_format}
{column_match_clause}
{on_error_clause}
"""

@staticmethod
def ensure_snowflake_azure_url(
file_url: str, account_name: str = None, account_host: str = None
) -> str:
# Explicit azure credentials are needed to load from bucket without a named stage
if not account_host and account_name:
account_host = f"{account_name}.blob.core.windows.net"
# get canonical url first to convert it into snowflake form
canonical_url = ensure_canonical_az_url(
file_url,
"azure",
account_name,
account_host,
)
parsed_file_url = urlparse(canonical_url)
return urlunparse(
parsed_file_url._replace(
path=f"/{parsed_file_url.username}{parsed_file_url.path}",
netloc=parsed_file_url.hostname,
)
)


class SnowflakeClient(SqlJobClientWithStagingDataset, SupportsStagingDestination):
def __init__(
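
In the diff above, run() now resolves stage_bucket_url from the staging config and passes keyword arguments to the module-level gen_copy_sql, so the FILES path is computed relative to the stage's bucket URL rather than the bucket root. A minimal sketch of that difference, using only the standard library; the bucket and file names are invented:

from urllib.parse import urlparse

file_url = "s3://my-bucket/staging/load_1/items.parquet"
stage_bucket_url = "s3://my-bucket/staging/"

# previous behaviour: path taken relative to the bucket root
old_files_path = urlparse(file_url).path.lstrip("/")      # 'staging/load_1/items.parquet'
# new behaviour: path taken relative to the stage's bucket URL
new_files_path = file_url.removeprefix(stage_bucket_url)  # 'load_1/items.parquet'

The second form matters when the external stage is defined on a sub-path of the bucket rather than the bucket itself.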
192 changes: 192 additions & 0 deletions dlt/destinations/impl/snowflake/utils.py
@@ -0,0 +1,192 @@
from typing import Literal, Optional, Dict, Any, Union, Tuple
from urllib.parse import urlparse, urlunparse
from enum import Enum

from dlt.common.configuration.specs import (
CredentialsConfiguration,
AwsCredentialsWithoutDefaults,
AzureCredentialsWithoutDefaults,
)
from dlt.common.data_writers.configuration import CsvFormatConfiguration
from dlt.common.storages import FilesystemConfiguration
from dlt.common.storages.configuration import ensure_canonical_az_url
from dlt.common.storages.fsspec_filesystem import (
AZURE_BLOB_STORAGE_PROTOCOLS,
S3_PROTOCOLS,
GCS_PROTOCOLS,
)
from dlt.common.typing import TLoaderFileFormat
from dlt.destinations.exceptions import LoadJobTerminalException


def ensure_snowflake_azure_url(
file_url: str, account_name: str = None, account_host: str = None
) -> str:
# Explicit azure credentials are needed to load from bucket without a named stage
if not account_host and account_name:
account_host = f"{account_name}.blob.core.windows.net"
# get canonical url first to convert it into snowflake form
canonical_url = ensure_canonical_az_url(
file_url,
"azure",
account_name,
account_host,
)
parsed_file_url = urlparse(canonical_url)
return urlunparse(
parsed_file_url._replace(
path=f"/{parsed_file_url.username}{parsed_file_url.path}",
netloc=parsed_file_url.hostname,
)
)
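
For illustration, a hedged sketch of the URL shape this helper targets; the account and container names are invented, and the exact result depends on ensure_canonical_az_url:

from dlt.destinations.impl.snowflake.utils import ensure_snowflake_azure_url

snowflake_url = ensure_snowflake_azure_url(
    "az://my-container/load_1/items.parquet",
    account_name="myaccount",
)
# Snowflake addresses Azure locations by account host, so the result should
# look roughly like:
#   azure://myaccount.blob.core.windows.net/my-container/load_1/items.parquet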


def generate_file_format_clause(
loader_file_format: TLoaderFileFormat, csv_format: Optional[CsvFormatConfiguration] = None
) -> Dict[str, Any]:
"""
Generates the FILE_FORMAT clause for the COPY command.
Args:
loader_file_format: The format of the source file (jsonl, parquet, csv)
csv_format: Optional configuration for CSV format
Returns:
A dictionary containing file format params and column_match_enabled flag
"""
if loader_file_format == "jsonl":
return {
"format_clause": "( TYPE = 'JSON', BINARY_FORMAT = 'BASE64' )",
"column_match_enabled": True,
}
elif loader_file_format == "parquet":
return {
"format_clause": "(TYPE = 'PARQUET', BINARY_AS_TEXT = FALSE, USE_LOGICAL_TYPE = TRUE)",
"column_match_enabled": True,
}
elif loader_file_format == "csv":
csv_config = csv_format or CsvFormatConfiguration()
return {
"format_clause": (
"(TYPE = 'CSV', BINARY_FORMAT = 'UTF-8', PARSE_HEADER ="
f" {csv_config.include_header}, FIELD_OPTIONALLY_ENCLOSED_BY = '\"', NULL_IF ="
" (''), ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE,"
f" FIELD_DELIMITER='{csv_config.delimiter}', ENCODING='{csv_config.encoding}')"
),
"column_match_enabled": csv_config.include_header,
"on_error_continue": csv_config.on_error_continue,
}
else:
raise ValueError(f"{loader_file_format} not supported for Snowflake COPY command.")
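
For illustration, a small sketch of calling this helper for headerless CSVs; it assumes dlt is installed, and the configuration values are invented:

from dlt.common.data_writers.configuration import CsvFormatConfiguration
from dlt.destinations.impl.snowflake.utils import generate_file_format_clause

fmt = generate_file_format_clause(
    "csv",
    CsvFormatConfiguration(include_header=False, on_error_continue=True),
)
print(fmt["format_clause"])         # CSV clause with PARSE_HEADER = False
print(fmt["column_match_enabled"])  # False, so MATCH_BY_COLUMN_NAME is omitted
print(fmt["on_error_continue"])     # True, so ON_ERROR = CONTINUE is emitted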


def gen_copy_sql(
file_url: str,
qualified_table_name: str,
loader_file_format: TLoaderFileFormat,
is_case_sensitive: bool,
stage_name: Optional[str] = None,
stage_bucket_url: Optional[str] = None,
local_stage_file_path: Optional[str] = None,
staging_credentials: Optional[CredentialsConfiguration] = None,
csv_format: Optional[CsvFormatConfiguration] = None,
) -> str:
"""
Generates a Snowflake COPY command to load data from a file.
Args:
file_url: URL of the file to load
qualified_table_name: Fully qualified name of the target table
loader_file_format: Format of the source file (jsonl, parquet, csv)
is_case_sensitive: Whether column matching should be case-sensitive
stage_name: Optional name of a predefined Snowflake stage
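stage_bucket_url: Optional bucket URL the stage is defined on; required when loading via stage_name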
local_stage_file_path: Path to use for local files
staging_credentials: Optional credentials for accessing cloud storage
csv_format: Optional configuration for CSV format
Returns:
A SQL string containing the COPY command
"""
# Determine file location type and get potentially modified URL
scheme = urlparse(file_url).scheme

# Initialize clause components
from_clause = ""
credentials_clause = ""
files_clause = ""
on_error_clause = ""

# Handle different file location types
if scheme == "file" or FilesystemConfiguration.is_local_path(file_url):
from_clause = f"FROM {local_stage_file_path}"

elif stage_name:
if not stage_bucket_url:
raise LoadJobTerminalException(
file_url, f"Cannot load from stage {stage_name} without a stage bucket URL."
)
relative_url = file_url.removeprefix(stage_bucket_url)
from_clause = f"FROM @{stage_name}"
files_clause = f"FILES = ('{relative_url}')"

elif scheme in S3_PROTOCOLS:
if staging_credentials and isinstance(staging_credentials, AwsCredentialsWithoutDefaults):
credentials_clause = (
f"CREDENTIALS=(AWS_KEY_ID='{staging_credentials.aws_access_key_id}' "
f"AWS_SECRET_KEY='{staging_credentials.aws_secret_access_key}')"
)
from_clause = f"FROM '{file_url}'"
else:
raise LoadJobTerminalException(
file_url,
f"Cannot load from S3 path {file_url} without either credentials or a stage name.",
)

elif scheme in AZURE_BLOB_STORAGE_PROTOCOLS:
if staging_credentials and isinstance(staging_credentials, AzureCredentialsWithoutDefaults):
credentials_clause = (
f"CREDENTIALS=(AZURE_SAS_TOKEN='?{staging_credentials.azure_storage_sas_token}')"
)
converted_az_url = ensure_snowflake_azure_url(
file_url,
staging_credentials.azure_storage_account_name,
staging_credentials.azure_account_host,
)
from_clause = f"FROM '{converted_az_url}'"
else:
raise LoadJobTerminalException(
file_url,
f"Cannot load from Azure path {file_url} without either credentials or a stage"
" name.",
)
else:
raise LoadJobTerminalException(
file_url,
f"Cannot load from bucket path {file_url} without a stage name. See "
"https://dlthub.com/docs/dlt-ecosystem/destinations/snowflake for "
"instructions on setting up the `stage_name`",
)

# Generate file format clause
format_info = generate_file_format_clause(loader_file_format, csv_format)
source_format = format_info["format_clause"]

# Set up column matching
column_match_clause = ""
if format_info.get("column_match_enabled", True):
case_folding = "CASE_SENSITIVE" if is_case_sensitive else "CASE_INSENSITIVE"
column_match_clause = f"MATCH_BY_COLUMN_NAME='{case_folding}'"

# Set up error handling
if loader_file_format == "csv" and format_info.get("on_error_continue", False):
on_error_clause = "ON_ERROR = CONTINUE"

# Construct the final SQL statement
return f"""COPY INTO {qualified_table_name}
{from_clause}
{files_clause}
{credentials_clause}
FILE_FORMAT = {source_format}
{column_match_clause}
{on_error_clause}
"""