Fix #2342 - Snowflake S3 Stage #2354

Open

wants to merge 9 commits into base: devel
165 changes: 17 additions & 148 deletions dlt/destinations/impl/snowflake/snowflake.py
@@ -1,8 +1,6 @@
from typing import Optional, Sequence, List, Dict, Set
from urllib.parse import urlparse, urlunparse
from typing import Optional, Sequence, List, Dict

from dlt.common import logger
from dlt.common.data_writers.configuration import CsvFormatConfiguration
from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.destination.client import (
HasFollowupJobs,
@@ -12,21 +10,14 @@
CredentialsConfiguration,
SupportsStagingDestination,
)
from dlt.common.configuration.specs import (
AwsCredentialsWithoutDefaults,
AzureCredentialsWithoutDefaults,
)
from dlt.common.schema.utils import get_columns_names_with_prop
from dlt.common.storages.configuration import FilesystemConfiguration, ensure_canonical_az_url
from dlt.common.storages.file_storage import FileStorage
from dlt.common.schema import TColumnSchema, Schema, TColumnHint
from dlt.common.schema.typing import TColumnType, TTableSchema

from dlt.common.storages.fsspec_filesystem import AZURE_BLOB_STORAGE_PROTOCOLS, S3_PROTOCOLS
from dlt.common.typing import TLoaderFileFormat
from dlt.common.utils import uniq_id
from dlt.destinations.impl.snowflake.utils import gen_copy_sql
from dlt.destinations.job_client_impl import SqlJobClientWithStagingDataset
from dlt.destinations.exceptions import LoadJobTerminalException

from dlt.destinations.impl.snowflake.configuration import SnowflakeClientConfiguration
from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient
@@ -76,16 +67,21 @@ def run(self) -> None:
)
stage_file_path = f'@{self._stage_name}/"{self._load_id}"/{file_name}'

copy_sql = self.gen_copy_sql(
file_url,
qualified_table_name,
file_format, # type: ignore[arg-type]
self._sql_client.capabilities.generates_case_sensitive_identifiers(),
self._stage_name,
stage_file_path,
self._staging_credentials,
self._config.csv_format,
self._config.use_vectorized_scanner,
stage_bucket_url = None
if self._config.staging_config and self._config.staging_config.bucket_url:
stage_bucket_url = self._config.staging_config.bucket_url

copy_sql = gen_copy_sql(
file_url=file_url,
qualified_table_name=qualified_table_name,
loader_file_format=file_format, # type: ignore[arg-type]
is_case_sensitive=self._sql_client.capabilities.generates_case_sensitive_identifiers(),
stage_name=self._stage_name,
stage_bucket_url=stage_bucket_url,
local_stage_file_path=stage_file_path,
staging_credentials=self._staging_credentials,
csv_format=self._config.csv_format,
use_vectorized_scanner=self._config.use_vectorized_scanner,
)

with self._sql_client.begin_transaction():
@@ -99,133 +95,6 @@ def run(self) -> None:
if stage_file_path and not self._keep_staged_files:
self._sql_client.execute_sql(f"REMOVE {stage_file_path}")

@classmethod
def gen_copy_sql(
cls,
file_url: str,
qualified_table_name: str,
loader_file_format: TLoaderFileFormat,
is_case_sensitive: bool,
stage_name: Optional[str] = None,
local_stage_file_path: Optional[str] = None,
staging_credentials: Optional[CredentialsConfiguration] = None,
csv_format: Optional[CsvFormatConfiguration] = None,
use_vectorized_scanner: Optional[bool] = False,
) -> str:
parsed_file_url = urlparse(file_url)
# check if local filesystem (file scheme or just a local file in native form)
is_local = parsed_file_url.scheme == "file" or FilesystemConfiguration.is_local_path(
file_url
)
# file_name = FileStorage.get_file_name_from_file_path(file_url)

from_clause = ""
credentials_clause = ""
files_clause = ""
on_error_clause = ""

case_folding = "CASE_SENSITIVE" if is_case_sensitive else "CASE_INSENSITIVE"
column_match_clause = f"MATCH_BY_COLUMN_NAME='{case_folding}'"

if not is_local:
bucket_scheme = parsed_file_url.scheme
# referencing an external s3/azure stage does not require explicit AWS credentials
if bucket_scheme in AZURE_BLOB_STORAGE_PROTOCOLS + S3_PROTOCOLS and stage_name:
from_clause = f"FROM '@{stage_name}'"
files_clause = f"FILES = ('{parsed_file_url.path.lstrip('/')}')"
# referencing staged files via a bucket URL requires explicit AWS credentials
elif (
bucket_scheme in S3_PROTOCOLS
and staging_credentials
and isinstance(staging_credentials, AwsCredentialsWithoutDefaults)
):
credentials_clause = f"""CREDENTIALS=(AWS_KEY_ID='{staging_credentials.aws_access_key_id}' AWS_SECRET_KEY='{staging_credentials.aws_secret_access_key}')"""
from_clause = f"FROM '{file_url}'"
elif (
bucket_scheme in AZURE_BLOB_STORAGE_PROTOCOLS
and staging_credentials
and isinstance(staging_credentials, AzureCredentialsWithoutDefaults)
):
credentials_clause = f"CREDENTIALS=(AZURE_SAS_TOKEN='?{staging_credentials.azure_storage_sas_token}')"
file_url = cls.ensure_snowflake_azure_url(
file_url,
staging_credentials.azure_storage_account_name,
staging_credentials.azure_account_host,
)
from_clause = f"FROM '{file_url}'"
else:
# ensure that gcs bucket path starts with gcs://, this is a requirement of snowflake
file_url = file_url.replace("gs://", "gcs://")
if not stage_name:
# when loading from bucket stage must be given
raise LoadJobTerminalException(
file_url,
f"Cannot load from bucket path {file_url} without a stage name. See"
" https://dlthub.com/docs/dlt-ecosystem/destinations/snowflake for"
" instructions on setting up the `stage_name`",
)
from_clause = f"FROM @{stage_name}/"
files_clause = f"FILES = ('{urlparse(file_url).path.lstrip('/')}')"
else:
from_clause = f"FROM {local_stage_file_path}"

# decide on source format, stage_file_path will either be a local file or a bucket path
if loader_file_format == "jsonl":
source_format = "( TYPE = 'JSON', BINARY_FORMAT = 'BASE64' )"
elif loader_file_format == "parquet":
source_format = "(TYPE = 'PARQUET', BINARY_AS_TEXT = FALSE, USE_LOGICAL_TYPE = TRUE"
if use_vectorized_scanner:
source_format += ", USE_VECTORIZED_SCANNER = TRUE"
on_error_clause = "ON_ERROR = ABORT_STATEMENT"
source_format += ")"
elif loader_file_format == "csv":
# empty strings are NULL, no data is NULL, missing columns (ERROR_ON_COLUMN_COUNT_MISMATCH) are NULL
csv_format = csv_format or CsvFormatConfiguration()
source_format = (
"(TYPE = 'CSV', BINARY_FORMAT = 'UTF-8', PARSE_HEADER ="
f" {csv_format.include_header}, FIELD_OPTIONALLY_ENCLOSED_BY = '\"', NULL_IF ="
" (''), ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE,"
f" FIELD_DELIMITER='{csv_format.delimiter}', ENCODING='{csv_format.encoding}')"
)
# disable column match if headers are not provided
if not csv_format.include_header:
column_match_clause = ""
if csv_format.on_error_continue:
on_error_clause = "ON_ERROR = CONTINUE"
else:
raise ValueError(f"{loader_file_format} not supported for Snowflake COPY command.")

return f"""COPY INTO {qualified_table_name}
{from_clause}
{files_clause}
{credentials_clause}
FILE_FORMAT = {source_format}
{column_match_clause}
{on_error_clause}
"""

@staticmethod
def ensure_snowflake_azure_url(
file_url: str, account_name: str = None, account_host: str = None
) -> str:
# Explicit azure credentials are needed to load from bucket without a named stage
if not account_host and account_name:
account_host = f"{account_name}.blob.core.windows.net"
# get canonical url first to convert it into snowflake form
canonical_url = ensure_canonical_az_url(
file_url,
"azure",
account_name,
account_host,
)
parsed_file_url = urlparse(canonical_url)
return urlunparse(
parsed_file_url._replace(
path=f"/{parsed_file_url.username}{parsed_file_url.path}",
netloc=parsed_file_url.hostname,
)
)


class SnowflakeClient(SqlJobClientWithStagingDataset, SupportsStagingDestination):
def __init__(
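For context, a minimal sketch (not part of the diff) of how a named Snowflake stage and an S3 staging bucket would be wired up so that this code path receives both `stage_name` and `stage_bucket_url`. The names and values are placeholders; the keys follow dlt's usual environment-variable configuration convention:

import os

import dlt

# Illustrative placeholders: the stage must already exist in Snowflake and point at the bucket.
os.environ["DESTINATION__SNOWFLAKE__STAGE_NAME"] = "PUBLIC.MY_S3_STAGE"
os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] = "s3://my-bucket/staging/"

pipeline = dlt.pipeline(
    pipeline_name="snowflake_stage_demo",
    destination="snowflake",
    staging="filesystem",
    dataset_name="my_dataset",
)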
195 changes: 195 additions & 0 deletions dlt/destinations/impl/snowflake/utils.py
@@ -0,0 +1,195 @@
from typing import Optional
from urllib.parse import urlparse, urlunparse

from dlt.common.configuration.specs import (
CredentialsConfiguration,
AwsCredentialsWithoutDefaults,
AzureCredentialsWithoutDefaults,
)
from dlt.common.data_writers.configuration import CsvFormatConfiguration
from dlt.common.storages import FilesystemConfiguration
from dlt.common.storages.configuration import ensure_canonical_az_url
from dlt.common.storages.fsspec_filesystem import (
AZURE_BLOB_STORAGE_PROTOCOLS,
S3_PROTOCOLS,
)

from dlt.common.typing import TLoaderFileFormat
from dlt.destinations.exceptions import LoadJobTerminalException


def ensure_snowflake_azure_url(
file_url: str, account_name: Optional[str] = None, account_host: Optional[str] = None
) -> str:
# Explicit azure credentials are needed to load from bucket without a named stage
if not account_host and account_name:
account_host = f"{account_name}.blob.core.windows.net"
# get canonical url first to convert it into snowflake form
canonical_url = ensure_canonical_az_url(
file_url,
"azure",
account_name,
account_host,
)
parsed_file_url = urlparse(canonical_url)
return urlunparse(
parsed_file_url._replace(
path=f"/{parsed_file_url.username}{parsed_file_url.path}",
netloc=parsed_file_url.hostname,
)
)
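

# Illustrative sketch of the rewrite above (assuming dlt's canonical Azure form keeps the
# container in the userinfo part of the URL; names are placeholders):
#   az://my-container/path/file.parquet  with account_name="myaccount"
#   -> canonical: azure://my-container@myaccount.blob.core.windows.net/path/file.parquet
#   -> Snowflake: azure://myaccount.blob.core.windows.net/my-container/path/file.parquet
# i.e. the azure://<account>.blob.core.windows.net/<container>/<path> form expected by COPY INTO.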


def _build_format_clause(format_clause: list[str]) -> str:
joined = ", ".join(format_clause)
return f"({joined})"


def gen_copy_sql(
file_url: str,
qualified_table_name: str,
loader_file_format: TLoaderFileFormat,
is_case_sensitive: bool,
stage_name: Optional[str] = None,
stage_bucket_url: Optional[str] = None,
local_stage_file_path: Optional[str] = None,
staging_credentials: Optional[CredentialsConfiguration] = None,
csv_format: Optional[CsvFormatConfiguration] = None,
use_vectorized_scanner: Optional[bool] = False,
) -> str:
"""
Generates a Snowflake COPY command to load data from a file.

Args:
file_url: URL of the file to load
qualified_table_name: Fully qualified name of the target table
loader_file_format: Format of the source file (jsonl, parquet, csv)
is_case_sensitive: Whether column matching should be case-sensitive
stage_name: Optional name of a predefined Snowflake stage
stage_bucket_url: Optional URL of the bucket containing the file
local_stage_file_path: Path to use for local files
staging_credentials: Optional credentials for accessing cloud storage
csv_format: Optional configuration for CSV format
use_vectorized_scanner: Whether to use the vectorized scanner in COPY INTO

Returns:
A SQL string containing the COPY command
"""
# Determine the file location type from the parsed URL
parsed_file_url = urlparse(file_url)

# Initialize clause components
credentials_clause = ""
files_clause = ""

# Handle different file location types
if parsed_file_url.scheme == "file" or FilesystemConfiguration.is_local_path(file_url):
from_clause = f"FROM {local_stage_file_path}"

elif stage_name:
relative_url = parsed_file_url.path.lstrip("/")

# If the stage bucket URL has a path, strip it from the beginning of the relative URL,
# since the Snowflake stage location already includes it.
if stage_bucket_url:
parsed_bucket_url = urlparse(stage_bucket_url)
stage_bucket_path = parsed_bucket_url.path.lstrip("/")
if stage_bucket_path:
relative_url = relative_url.removeprefix(stage_bucket_path)

from_clause = f"FROM @{stage_name}"
files_clause = f"FILES = ('{relative_url}')"

elif parsed_file_url.scheme in S3_PROTOCOLS:
if staging_credentials and isinstance(staging_credentials, AwsCredentialsWithoutDefaults):
credentials_clause = (
f"CREDENTIALS=(AWS_KEY_ID='{staging_credentials.aws_access_key_id}' "
f"AWS_SECRET_KEY='{staging_credentials.aws_secret_access_key}')"
)
from_clause = f"FROM '{file_url}'"
else:
raise LoadJobTerminalException(
file_url,
f"Cannot load from S3 path {file_url} without either credentials or a stage name.",
)

elif parsed_file_url.scheme in AZURE_BLOB_STORAGE_PROTOCOLS:
if staging_credentials and isinstance(staging_credentials, AzureCredentialsWithoutDefaults):
credentials_clause = (
f"CREDENTIALS=(AZURE_SAS_TOKEN='?{staging_credentials.azure_storage_sas_token}')"
)
converted_az_url = ensure_snowflake_azure_url(
file_url,
staging_credentials.azure_storage_account_name,
staging_credentials.azure_account_host,
)
from_clause = f"FROM '{converted_az_url}'"
else:
raise LoadJobTerminalException(
file_url,
f"Cannot load from Azure path {file_url} without either credentials or a stage"
" name.",
)
else:
raise LoadJobTerminalException(
file_url,
f"Cannot load from bucket path {file_url} without a stage name. See "
"https://dlthub.com/docs/dlt-ecosystem/destinations/snowflake for "
"instructions on setting up the `stage_name`",
)

# Generate column match and file format clauses
if is_case_sensitive:
column_match_clause = "MATCH_BY_COLUMN_NAME='CASE_SENSITIVE'"
else:
column_match_clause = "MATCH_BY_COLUMN_NAME='CASE_INSENSITIVE'"

if loader_file_format == "jsonl":
on_error_clause = ""
format_opts = [
"TYPE = 'JSON'",
"BINARY_FORMAT = 'BASE64'",
]
elif loader_file_format == "parquet":
format_opts = [
"TYPE = 'PARQUET'",
"BINARY_AS_TEXT = FALSE",
"USE_LOGICAL_TYPE = TRUE",
]
if use_vectorized_scanner:
format_opts.append("USE_VECTORIZED_SCANNER = TRUE")
on_error_clause = "ON_ERROR = ABORT_STATEMENT"
else:
on_error_clause = ""
elif loader_file_format == "csv":
csv_config = csv_format or CsvFormatConfiguration()
if not csv_config.include_header:
column_match_clause = ""
if csv_config.on_error_continue:
on_error_clause = "ON_ERROR = CONTINUE"
else:
on_error_clause = ""
format_opts = [
"TYPE = 'CSV'",
"BINARY_FORMAT = 'UTF-8'",
f"PARSE_HEADER = {csv_config.include_header}",
"FIELD_OPTIONALLY_ENCLOSED_BY = '\"'",
"NULL_IF = ('')",
"ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE",
f"FIELD_DELIMITER = '{csv_config.delimiter}'",
f"ENCODING = '{csv_config.encoding}'",
]
else:
raise ValueError(f"{loader_file_format} not supported for Snowflake COPY command.")

source_format = _build_format_clause(format_opts)

# Construct the final SQL statement
return f"""COPY INTO {qualified_table_name}
{from_clause}
{files_clause}
{credentials_clause}
FILE_FORMAT = {source_format}
{column_match_clause}
{on_error_clause}
"""