feat(datasets): Improved compatibility, functionality and testing for SnowflakeTableDataset (#881)

Signed-off-by: tdhooghe <[email protected]>
Showing 6 changed files with 540 additions and 294 deletions.
@@ -1,95 +1,99 @@
-"""``AbstractDataset`` implementation to access Snowflake using Snowpark dataframes
-"""
+"""``AbstractDataset`` implementation to access Snowflake using Snowpark dataframes"""
+
 from __future__ import annotations
 
 import logging
-from typing import Any
+from typing import Any, cast
 
-import snowflake.snowpark as sp
+import pandas as pd
 from kedro.io.core import AbstractDataset, DatasetError
+from snowflake.snowpark import DataFrame, Session
+from snowflake.snowpark import context as sp_context
+from snowflake.snowpark import exceptions as sp_exceptions
 
 logger = logging.getLogger(__name__)
 
 
 class SnowparkTableDataset(AbstractDataset):
-    """``SnowparkTableDataset`` loads and saves Snowpark dataframes.
+    """``SnowparkTableDataset`` loads and saves Snowpark DataFrames.
 
-    As of Mar-2023, the snowpark connector only works with Python 3.8.
+    As of October 2024, the Snowpark connector works with Python 3.9, 3.10, and 3.11.
+    Python 3.12 is not supported yet.
 
     Example usage for the
     `YAML API <https://docs.kedro.org/en/stable/data/\
     data_catalog_yaml_examples.html>`_:
 
     .. code-block:: yaml
 
         weather:
          type: kedro_datasets.snowflake.SnowparkTableDataset
          table_name: "weather_data"
          database: "meteorology"
          schema: "observations"
          credentials: db_credentials
          save_args:
            mode: overwrite
            column_order: name
            table_type: ''
 
-    You can skip everything but "table_name" if the database and
-    schema are provided via credentials. That way catalog entries can be shorter
-    if, for example, all used Snowflake tables live in same database/schema.
-    Values in the dataset definition take priority over those defined in credentials.
+    You can skip everything but "table_name" if the database and schema are
+    provided via credentials. This allows catalog entries to be shorter when
+    all Snowflake tables are in the same database and schema. Values in the
+    dataset definition take priority over those defined in credentials.
 
     Example:
-    Credentials file provides all connection attributes, catalog entry
-    "weather" reuses credentials parameters, "polygons" catalog entry reuses
-    all credentials parameters except providing a different schema name.
-    Second example of credentials file uses ``externalbrowser`` authentication.
+    The credentials file provides all connection attributes. The catalog entry
+    for "weather" reuses the credentials parameters, while the "polygons" catalog
+    entry reuses all credentials parameters except for specifying a different
+    schema. The second example demonstrates the use of ``externalbrowser``
+    authentication.
 
-    catalog.yml
+    catalog.yml:
 
     .. code-block:: yaml
 
        weather:
          type: kedro_datasets.snowflake.SnowparkTableDataset
          table_name: "weather_data"
          database: "meteorology"
          schema: "observations"
          credentials: snowflake_client
          save_args:
            mode: overwrite
            column_order: name
            table_type: ''
 
        polygons:
          type: kedro_datasets.snowflake.SnowparkTableDataset
          table_name: "geopolygons"
          credentials: snowflake_client
          schema: "geodata"
 
-    credentials.yml
+    credentials.yml:
 
     .. code-block:: yaml
 
        snowflake_client:
          account: 'ab12345.eu-central-1'
          port: 443
          warehouse: "datascience_wh"
          database: "detailed_data"
          schema: "observations"
          user: "service_account_abc"
          password: "supersecret"
 
-    credentials.yml (with externalbrowser authenticator)
+    credentials.yml (with externalbrowser authentication):
 
     .. code-block:: yaml
 
        snowflake_client:
          account: 'ab12345.eu-central-1'
          port: 443
          warehouse: "datascience_wh"
          database: "detailed_data"
          schema: "observations"
          user: "[email protected]"
          authenticator: "externalbrowser"
     """
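Alongside the YAML examples above, here is a minimal sketch of the equivalent programmatic construction. This is illustrative only: the credential values are placeholders, and it assumes the same "weather" table as the docstring examples.

.. code-block:: python

    from kedro_datasets.snowflake import SnowparkTableDataset

    credentials = {
        "account": "ab12345.eu-central-1",
        "warehouse": "datascience_wh",
        "database": "meteorology",
        "schema": "observations",
        "user": "service_account_abc",
        "password": "supersecret",
    }

    weather = SnowparkTableDataset(
        table_name="weather_data",
        credentials=credentials,
        save_args={"mode": "overwrite"},
    )
    session = weather.session  # reuses an active session or creates one
    df = weather.load()        # returns a snowflake.snowpark.DataFrame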
@@ -110,9 +114,11 @@ def __init__(  # noqa: PLR0913
         load_args: dict[str, Any] | None = None,
         save_args: dict[str, Any] | None = None,
         credentials: dict[str, Any] | None = None,
+        session: Session | None = None,
         metadata: dict[str, Any] | None = None,
     ) -> None:
-        """Creates a new instance of ``SnowparkTableDataset``.
+        """
+        Creates a new instance of ``SnowparkTableDataset``.
 
         Args:
             table_name: The table name to load or save data to.
 
@@ -154,6 +160,7 @@ def __init__(  # noqa: PLR0913
                     "'schema' must be provided by credentials or dataset."
                 )
             schema = credentials["schema"]
+
         # Handle default load and save arguments
         self._load_args = {**self.DEFAULT_LOAD_ARGS, **(load_args or {})}
         self._save_args = {**self.DEFAULT_SAVE_ARGS, **(save_args or {})}
 
@@ -167,6 +174,7 @@ def __init__(  # noqa: PLR0913
             {"database": self._database, "schema": self._schema}
        )
         self._connection_parameters = connection_parameters
+        self._session = session
 
         self.metadata = metadata
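The new ``session`` constructor argument lets callers inject an existing Snowpark session instead of having the dataset build one from credentials, for example when running inside Snowflake where a session already exists. A minimal sketch, with placeholder connection values:

.. code-block:: python

    from snowflake.snowpark import Session

    from kedro_datasets.snowflake import SnowparkTableDataset

    session = Session.builder.configs(
        {
            "account": "ab12345.eu-central-1",
            "user": "service_account_abc",
            "password": "supersecret",
            "warehouse": "datascience_wh",
        }
    ).create()

    polygons = SnowparkTableDataset(
        table_name="geopolygons",
        database="meteorology",
        schema="geodata",
        session=session,
    )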
@@ -178,8 +186,9 @@ def _describe(self) -> dict[str, Any]:
         }
 
     @staticmethod
-    def _get_session(connection_parameters) -> sp.Session:
-        """Given a connection string, create singleton connection
+    def _get_session(connection_parameters) -> Session:
+        """
+        Given a connection string, create singleton connection
         to be used across all instances of `SnowparkTableDataset` that
         need to connect to the same source.
         connection_parameters is a dictionary of any values
@@ -199,45 +208,96 @@ def _get_session(connection_parameters) -> sp.Session:
         """
         try:
             logger.debug("Trying to reuse active snowpark session...")
-            session = sp.context.get_active_session()
-        except sp.exceptions.SnowparkSessionException:
+            session = sp_context.get_active_session()
+        except sp_exceptions.SnowparkSessionException:
             logger.debug("No active snowpark session found. Creating...")
-            session = sp.Session.builder.configs(connection_parameters).create()
+            session = Session.builder.configs(connection_parameters).create()
         return session
 
     @property
-    def _session(self) -> sp.Session:
-        return self._get_session(self._connection_parameters)
+    def session(self) -> Session:
+        """
+        Retrieve or create a session.
+
+        Returns:
+            Session: The current session associated with the object.
+        """
+        if not self._session:
+            self._session = self._get_session(self._connection_parameters)
+        return self._session
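The session resolution order is worth spelling out: the ``session`` property returns the injected or cached session if one is set; otherwise ``_get_session`` first tries ``get_active_session()`` and only then builds a new session from the stored connection parameters. A short sketch of the caching behaviour, assuming the ``weather`` dataset from the first sketch:

.. code-block:: python

    # First access: reuse an active Snowpark session if one exists,
    # otherwise create one from the connection parameters.
    session = weather.session

    # Subsequent accesses return the same cached session object.
    assert weather.session is session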
 
-    def load(self) -> sp.DataFrame:
-        table_name: list = [
-            self._database,
-            self._schema,
-            self._table_name,
-        ]
-
-        sp_df = self._session.table(".".join(table_name))
-        return sp_df
+    def load(self) -> DataFrame:
+        """
+        Load data from a specified database table.
+
+        Returns:
+            DataFrame: The loaded data as a Snowpark DataFrame.
+        """
+        if self._session is None:
+            raise DatasetError(
+                "No active session. Please initialise a Snowpark session before loading data."
+            )
+        return self._session.table(self._validate_and_get_table_name())
 
-    def save(self, data: sp.DataFrame) -> None:
-        table_name = [
-            self._database,
-            self._schema,
-            self._table_name,
-        ]
-
-        data.write.save_as_table(table_name, **self._save_args)
+    def save(self, data: pd.DataFrame | DataFrame) -> None:
+        """
+        Check if the data is a Snowpark DataFrame or a pandas DataFrame,
+        convert it to a Snowpark DataFrame if needed, and save it to the
+        specified table.
+
+        Args:
+            data (pd.DataFrame | DataFrame): The data to save.
+        """
+        if self._session is None:
+            raise DatasetError(
+                "No active session. Please initialise a Snowpark session before saving data."
+            )
+        if isinstance(data, pd.DataFrame):
+            snowpark_df = self._session.create_dataframe(data)
+        elif isinstance(data, DataFrame):
+            snowpark_df = data
+        else:
+            raise DatasetError(
+                f"Data of type {type(data)} is not supported for saving."
+            )
+
+        snowpark_df.write.save_as_table(
+            self._validate_and_get_table_name(), **self._save_args
+        )
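Because ``save`` now accepts both pandas and Snowpark DataFrames, a node can return a plain pandas frame and let the dataset convert it via ``create_dataframe``. A minimal round-trip sketch, reusing the hypothetical ``weather`` dataset from the first sketch:

.. code-block:: python

    import pandas as pd

    local = pd.DataFrame({"city": ["Berlin", "Paris"], "temp_c": [21.5, 19.0]})
    weather.save(local)       # pandas input is converted to a Snowpark DataFrame

    snow_df = weather.load()  # loaded back as a Snowpark DataFrame
    weather.save(snow_df)     # Snowpark input is written as-is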
 
     def _exists(self) -> bool:
-        session = self._session
-        query = "SELECT COUNT(*) FROM {database}.INFORMATION_SCHEMA.TABLES \
-            WHERE TABLE_SCHEMA = '{schema}' \
-            AND TABLE_NAME = '{table_name}'"
-        rows = session.sql(
-            query.format(
-                database=self._database,
-                schema=self._schema,
-                table_name=self._table_name,
-            )
-        ).collect()
-        return rows[0][0] == 1
+        """
+        Check if a specified table exists in the database.
+
+        Returns:
+            bool: True if the table exists, False otherwise.
+        """
+        if self._session is None:
+            raise DatasetError(
+                "No active session. Please initialise a Snowpark session before "
+                "checking that a table exists."
+            )
+        try:
+            self._session.table(
+                f"{self._database}.{self._schema}.{self._table_name}"
+            ).show()
+            return True
+        except Exception as e:
+            logger.debug(f"Table {self._table_name} does not exist: {e}")
+            return False
 
+    def _validate_and_get_table_name(self) -> str:
+        """
+        Validate that no part of the table name is None or empty and join the
+        parts into a single string.
+
+        Returns:
+            str: The joined table name in the format 'database.schema.table'.
+
+        Raises:
+            DatasetError: If any part of the table name is None or empty.
+        """
+        parts: list[str | None] = [self._database, self._schema, self._table_name]
+        if any(part is None or part == "" for part in parts):
+            raise DatasetError("Database, schema or table name cannot be None or empty")
+        parts_str = cast(list[str], parts)  # make linting happy
+        return ".".join(parts_str)
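To see the validate-then-join behaviour in isolation, here is a standalone snippet that mirrors the helper rather than calling it (``ValueError`` stands in for kedro's ``DatasetError``):

.. code-block:: python

    from typing import cast

    def validate_and_join(database: str | None, schema: str | None, table: str | None) -> str:
        # Mirrors _validate_and_get_table_name: reject missing parts, then join.
        parts = [database, schema, table]
        if any(part is None or part == "" for part in parts):
            raise ValueError("Database, schema or table name cannot be None or empty")
        return ".".join(cast(list[str], parts))

    print(validate_and_join("meteorology", "observations", "weather_data"))
    # -> meteorology.observations.weather_data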