Adding type annotations to fixtures and tests
stevenayers committed Mar 10, 2025
1 parent 0466a7c commit 0029f1e
Showing 3 changed files with 62 additions and 36 deletions.
4 changes: 2 additions & 2 deletions tests/load/pipeline/test_snowflake_pipeline.py
@@ -229,11 +229,11 @@ def test_snowflake_use_vectorized_scanner(
) -> None:
"""Tests whether the vectorized scanner option is correctly applied when loading Parquet files into Snowflake."""

from dlt.destinations.impl.snowflake.utils import gen_copy_sql
from dlt.destinations.impl.snowflake import utils

os.environ["DESTINATION__SNOWFLAKE__USE_VECTORIZED_SCANNER"] = use_vectorized_scanner

load_job_spy = mocker.spy(gen_copy_sql)
load_job_spy = mocker.spy(utils, "gen_copy_sql")

data_types = deepcopy(TABLE_ROW_ALL_DATA_TYPES_DATETIMES)
column_schemas = deepcopy(TABLE_UPDATE_COLUMNS_SCHEMA)
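
The fix in this hunk reflects how pytest-mock's MockerFixture.spy works: it takes an object and an attribute name, replaces that attribute with a wrapper that records calls, and still calls through to the original. Passing the bare function, as the old line did, is not a supported call. A minimal sketch of the pattern, using hypothetical helpers.build_sql and helpers.run_load names rather than the dlt internals:

    from pytest_mock import MockerFixture

    import helpers  # hypothetical module under test


    def test_build_sql_is_called(mocker: MockerFixture) -> None:
        # spy(obj, name) patches helpers.build_sql with a recording wrapper
        spy = mocker.spy(helpers, "build_sql")

        helpers.run_load()  # hypothetical code path that calls build_sql internally

        spy.assert_called_once()
        assert spy.spy_return is not None  # the real function still ran and returned
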
18 changes: 9 additions & 9 deletions tests/load/snowflake/conftest.py
@@ -8,51 +8,51 @@


@pytest.fixture
def test_table():
def test_table() -> str:
"""Fixture for test table name."""
return "test_schema.test_table"


@pytest.fixture
def stage_bucket_url():
def stage_bucket_url() -> str:
"""Fixture for stage bucket URL."""
return "s3://test-stage-bucket/"


@pytest.fixture
def stage_bucket_url_with_prefix():
def stage_bucket_url_with_prefix() -> str:
"""Fixture for stage bucket URL."""
return "s3://test-stage-bucket/with/prefix"


@pytest.fixture
def local_file_path():
def local_file_path() -> str:
"""Fixture for local file path."""
return "/tmp/data.csv"


@pytest.fixture
def stage_name():
def stage_name() -> str:
"""Fixture for stage name."""
return "test_stage"


@pytest.fixture
def local_stage_path():
def local_stage_path() -> str:
"""Fixture for local stage file path."""
return "@%temp_stage/data.csv"


@pytest.fixture
def aws_credentials():
def aws_credentials() -> AwsCredentialsWithoutDefaults:
"""Fixture for AWS credentials."""
return AwsCredentialsWithoutDefaults(
aws_access_key_id="test_aws_key", aws_secret_access_key="test_aws_secret"
)


@pytest.fixture
def azure_credentials():
def azure_credentials() -> AzureCredentialsWithoutDefaults:
"""Fixture for Azure credentials."""
return AzureCredentialsWithoutDefaults(
azure_storage_account_name="teststorage",
@@ -62,7 +62,7 @@ def azure_credentials():


@pytest.fixture
def default_csv_format():
def default_csv_format() -> CsvFormatConfiguration:
"""Fixture for default CSV format."""
return CsvFormatConfiguration(
include_header=True, delimiter=",", encoding="UTF-8", on_error_continue=False
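
The return annotations added to these fixtures are what a type checker sees when the fixture is injected into a test, so the test parameters in the file below can be annotated to match. A minimal, self-contained sketch of the pattern (hypothetical names, not part of the dlt test suite):

    import pytest


    @pytest.fixture
    def bucket_url() -> str:
        """Typed fixture: the annotation documents what tests receive."""
        return "s3://example-bucket/"


    def test_bucket_url_has_scheme(bucket_url: str) -> None:
        # The parameter annotation mirrors the fixture's return type, so
        # mypy/pyright can check how the injected value is used.
        assert bucket_url.startswith("s3://")
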
76 changes: 51 additions & 25 deletions tests/load/snowflake/test_snowflake_utils.py
@@ -1,6 +1,10 @@
import pytest
from urllib.parse import urlparse

from dlt.common.configuration.specs import (
AzureCredentialsWithoutDefaults,
AwsCredentialsWithoutDefaults,
)
from dlt.common.data_writers.configuration import CsvFormatConfiguration
from dlt.common.exceptions import TerminalValueError
from dlt.destinations.exceptions import LoadJobTerminalException
@@ -9,19 +13,21 @@
ensure_snowflake_azure_url,
)

# mark all tests as essential, do not remove
pytestmark = pytest.mark.essential

# ----------------------------------
# Helper Functions
# ----------------------------------


def assert_sql_contains(sql, *phrases):
def assert_sql_contains(sql: str, *phrases):
"""Assert that SQL contains all the given phrases."""
for phrase in phrases:
assert phrase in sql, f"Expected '{phrase}' in SQL, but not found:\n{sql}"


def assert_sql_not_contains(sql, *phrases):
def assert_sql_not_contains(sql: str, *phrases):
"""Assert that SQL does not contain any of the given phrases."""
for phrase in phrases:
assert phrase not in sql, f"Found unexpected '{phrase}' in SQL:\n{sql}"
@@ -32,7 +38,7 @@ def assert_sql_not_contains(sql, *phrases):
# ----------------------------------


def test_file_format_jsonl(test_table, local_file_path, local_stage_path):
def test_file_format_jsonl(test_table: str, local_file_path: str, local_stage_path: str):
"""Test JSONL format handling in gen_copy_sql."""
sql = gen_copy_sql(
file_url=local_file_path,
@@ -50,7 +56,7 @@ def test_file_format_jsonl(test_table, local_file_path, local_stage_path):
)


def test_file_format_parquet(test_table, local_file_path, local_stage_path):
def test_file_format_parquet(test_table: str, local_file_path: str, local_stage_path: str):
"""Test Parquet format handling in gen_copy_sql."""
sql = gen_copy_sql(
file_url=local_file_path,
@@ -69,7 +75,9 @@ def test_file_format_parquet(test_table, local_file_path, local_stage_path):
)


def test_file_format_parquet_vectorized(test_table, local_file_path, local_stage_path):
def test_file_format_parquet_vectorized(
test_table: str, local_file_path: str, local_stage_path: str
):
"""Test Parquet format with vectorized scanner in gen_copy_sql."""
sql = gen_copy_sql(
file_url=local_file_path,
@@ -101,13 +109,13 @@ def test_file_format_parquet_vectorized(test_table, local_file_path, local_stage
],
)
def test_file_format_csv(
include_header,
delimiter,
encoding,
on_error_continue,
test_table,
local_file_path,
local_stage_path,
include_header: bool,
delimiter: str,
encoding: str,
on_error_continue: bool,
test_table: str,
local_file_path: str,
local_stage_path: str,
):
"""Test CSV format handling in gen_copy_sql with various options."""
csv_config = CsvFormatConfiguration(
@@ -153,7 +161,7 @@ def test_file_format_csv(
# ----------------------------------


def test_gen_copy_sql_local_file(test_table, local_file_path, local_stage_path):
def test_gen_copy_sql_local_file(test_table: str, local_file_path: str, local_stage_path: str):
"""Test generating COPY command for local files."""
sql = gen_copy_sql(
file_url=local_file_path,
@@ -172,7 +180,7 @@ def test_gen_copy_sql_local_file(test_table, local_file_path, local_stage_path):
)


def test_gen_copy_sql_with_stage(test_table, stage_name, stage_bucket_url):
def test_gen_copy_sql_with_stage(test_table: str, stage_name: str, stage_bucket_url: str):
"""Test generating COPY command with a named stage."""
file_url = f"{stage_bucket_url}path/to/file.parquet"

@@ -196,7 +204,7 @@ def test_gen_copy_sql_with_stage(test_table, stage_name, stage_bucket_url):


def test_gen_copy_sql_with_stage_with_prefix_no_slash(
test_table, stage_name, stage_bucket_url_with_prefix
test_table: str, stage_name: str, stage_bucket_url_with_prefix: str
):
"""Test generating COPY command with a named stage and bucket url without a forward slash."""
file_url = f"{stage_bucket_url_with_prefix}path/to/file.parquet"
@@ -221,7 +229,7 @@ def test_gen_copy_sql_with_stage_with_prefix_no_slash(


def test_gen_copy_sql_with_stage_with_prefix_slash(
test_table, stage_name, stage_bucket_url_with_prefix
test_table: str, stage_name: str, stage_bucket_url_with_prefix: str
):
"""Test generating COPY command with a named stage abd bucket url with a forward slash."""
stage_bucket_url_with_prefix_slash = f"{stage_bucket_url_with_prefix}/"
@@ -246,7 +254,9 @@ def test_gen_copy_sql_with_stage_with_prefix_slash(
)


def test_gen_copy_sql_s3_with_credentials(test_table, aws_credentials):
def test_gen_copy_sql_s3_with_credentials(
test_table: str, aws_credentials: AwsCredentialsWithoutDefaults
):
"""Test generating COPY command for S3 with AWS credentials."""
s3_url = "s3://bucket/path/to/file.jsonl"

@@ -268,7 +278,7 @@ def test_gen_copy_sql_s3_with_credentials(test_table, aws_credentials):
)


def test_gen_copy_sql_s3_without_credentials_or_stage(test_table):
def test_gen_copy_sql_s3_without_credentials_or_stage(test_table: str):
"""Test that using S3 without credentials or stage raises an error."""
s3_url = "s3://bucket/path/to/file.jsonl"

@@ -284,7 +294,11 @@ def test_gen_copy_sql_s3_without_credentials_or_stage(test_table):
assert "without either credentials or a stage name" in str(excinfo.value)


def test_gen_copy_sql_azure_with_credentials(test_table, azure_credentials, default_csv_format):
def test_gen_copy_sql_azure_with_credentials(
test_table: str,
azure_credentials: AzureCredentialsWithoutDefaults,
default_csv_format: CsvFormatConfiguration,
):
"""Test generating COPY command for Azure Blob with credentials."""
azure_url = "azure://teststorage.blob.core.windows.net/container/file.csv"

@@ -320,7 +334,7 @@ def test_gen_copy_sql_azure_with_credentials(test_table, azure_credentials, defa
assert "container" in parsed_url.path, "Path doesn't contain container name"


def test_gen_copy_sql_azure_without_credentials_or_stage(test_table):
def test_gen_copy_sql_azure_without_credentials_or_stage(test_table: str):
"""Test that using Azure Blob without credentials or stage raises an error."""
azure_url = "azure://account.blob.core.windows.net/container/file.csv"

@@ -336,7 +350,7 @@ def test_gen_copy_sql_azure_without_credentials_or_stage(test_table):
assert "without either credentials or a stage name" in str(excinfo.value)


def test_gen_copy_sql_gcs_without_stage(test_table):
def test_gen_copy_sql_gcs_without_stage(test_table: str):
"""Test that using GCS without stage raises an error."""
gcs_url = "gs://bucket/path/to/file.jsonl"

@@ -356,7 +370,11 @@ def test_gen_copy_sql_gcs_without_stage(test_table):
"is_case_sensitive,expected_case", [(True, "CASE_SENSITIVE"), (False, "CASE_INSENSITIVE")]
)
def test_gen_copy_sql_case_sensitivity(
is_case_sensitive, expected_case, test_table, local_file_path, local_stage_path
is_case_sensitive: bool,
expected_case: str,
test_table: str,
local_file_path: str,
local_stage_path: str,
):
"""Test case sensitivity setting in COPY command."""
sql = gen_copy_sql(
@@ -374,7 +392,11 @@ def test_gen_copy_sql_case_sensitivity(
"on_error_continue,include_header", [(True, True), (False, True), (True, False), (False, False)]
)
def test_gen_copy_sql_csv_options(
on_error_continue, include_header, test_table, local_file_path, local_stage_path
on_error_continue: bool,
include_header: bool,
test_table: str,
local_file_path: str,
local_stage_path: str,
):
"""Test CSV options in COPY command."""
csv_format = CsvFormatConfiguration(
@@ -406,7 +428,9 @@ def test_gen_copy_sql_csv_options(
assert_sql_not_contains(sql, "MATCH_BY_COLUMN_NAME")


def test_full_workflow_s3_with_aws_credentials(test_table, aws_credentials):
def test_full_workflow_s3_with_aws_credentials(
test_table: str, aws_credentials: AwsCredentialsWithoutDefaults
):
"""Test the full workflow for S3 with AWS credentials."""
# This test verifies that all components work together correctly
s3_url = "s3://test-bucket/path/to/data.jsonl"
@@ -433,7 +457,9 @@ def test_full_workflow_s3_with_aws_credentials(test_table, aws_credentials):
)


def test_full_workflow_azure_with_credentials(test_table, azure_credentials):
def test_full_workflow_azure_with_credentials(
test_table: str, azure_credentials: AzureCredentialsWithoutDefaults
):
"""Test the full workflow for Azure Blob with credentials."""
# This test verifies that all components work together correctly
azure_url = "azure://teststorage.blob.core.windows.net/container/file.parquet"
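
The module-level pytestmark = pytest.mark.essential added near the top of this file applies the marker to every test in the module, so the whole file can be selected with pytest -m essential or excluded with -m "not essential". A hypothetical sketch of registering such a marker so pytest does not warn about it; the dlt repository presumably already does this in its own test configuration:

    # conftest.py (sketch): register the custom marker used above
    def pytest_configure(config) -> None:
        config.addinivalue_line("markers", "essential: core tests that must always run")
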
