diff --git a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py new file mode 100644 index 000000000..4cbb2ea37 --- /dev/null +++ b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py @@ -0,0 +1,517 @@ +"""``BaseTableDataset`` implementation used to add the base for +``ManagedTableDataset`` and ``ExternalTableDataset``. +""" +from __future__ import annotations + +import logging +import re +from dataclasses import dataclass, field +from typing import Any, ClassVar + +import pandas as pd +from kedro.io.core import ( + AbstractVersionedDataset, + DatasetError, + Version, + VersionNotFoundError, +) +from pyspark.sql import DataFrame +from pyspark.sql.types import StructType +from pyspark.sql.utils import AnalysisException, ParseException + +from kedro_datasets.spark.spark_dataset import _get_spark + +logger = logging.getLogger(__name__) +pd.DataFrame.iteritems = pd.DataFrame.items + + +@dataclass(frozen=True) +class BaseTable: + """Stores the definition of a base table. + + Acts as the base class for `ManagedTable` and `ExternalTable`. + """ + + # regex for tables, catalogs and schemas + _NAMING_REGEX: ClassVar[str] = r"\b[0-9a-zA-Z_-]{1,}\b" + _VALID_WRITE_MODES: ClassVar[list[str]] = field( + default=["overwrite", "upsert", "append"] + ) + _VALID_DATAFRAME_TYPES: ClassVar[list[str]] = field(default=["spark", "pandas"]) + _VALID_FORMATS: ClassVar[list[str]] = field( + default=["delta", "parquet", "csv", "json", "orc", "avro", "text"] + ) + + database: str + catalog: str | None + table: str + write_mode: str | None + location: str | None + dataframe_type: str + primary_key: str | list[str] | None + owner_group: str | None + partition_columns: str | list[str] | None + format: str = "delta" + json_schema: dict[str, Any] | None = None + + def __post_init__(self): + """Run validation methods if declared. + + The validation method can be a simple check + that raises DatasetError. + + The validation is performed by calling a function with the signature + `validate_(self, value) -> raises DatasetError`. + """ + for name in self.__dataclass_fields__.keys(): + method = getattr(self, f"_validate_{name}", None) + if method: + method() + + def _validate_format(self): + """Validates the format of the table. + + Raises: + DatasetError: If an invalid `format` is passed. + """ + if self.format not in self._VALID_FORMATS: + valid_formats = ", ".join(self._VALID_FORMATS) + raise DatasetError( + f"Invalid `format` provided: {self.format}. " + f"`format` must be one of: {valid_formats}" + ) + + def _validate_table(self): + """Validates table name. + + Raises: + DatasetError: If the table name does not conform to naming constraints. + """ + if not re.fullmatch(self._NAMING_REGEX, self.table): + raise DatasetError("Table does not conform to naming") + + def _validate_database(self): + """Validates database name. + + Raises: + DatasetError: If the dataset name does not conform to naming constraints. + """ + if not re.fullmatch(self._NAMING_REGEX, self.database): + raise DatasetError("Database does not conform to naming") + + def _validate_catalog(self): + """Validates catalog name. + + Raises: + DatasetError: If the catalog name does not conform to naming constraints. + """ + if self.catalog: + if not re.fullmatch(self._NAMING_REGEX, self.catalog): + raise DatasetError("Catalog does not conform to naming") + + def _validate_write_mode(self): + """Validates the write mode. 
+
+        Raises:
+            DatasetError: If an invalid `write_mode` is passed.
+        """
+        if (
+            self.write_mode is not None
+            and self.write_mode not in self._VALID_WRITE_MODES
+        ):
+            valid_modes = ", ".join(self._VALID_WRITE_MODES)
+            raise DatasetError(
+                f"Invalid `write_mode` provided: {self.write_mode}. "
+                f"`write_mode` must be one of: {valid_modes}"
+            )
+
+    def _validate_dataframe_type(self):
+        """Validates the dataframe type.
+
+        Raises:
+            DatasetError: If an invalid `dataframe_type` is passed.
+        """
+        if self.dataframe_type not in self._VALID_DATAFRAME_TYPES:
+            valid_types = ", ".join(self._VALID_DATAFRAME_TYPES)
+            raise DatasetError(f"`dataframe_type` must be one of {valid_types}")
+
+    def _validate_primary_key(self):
+        """Validates the primary key of the table.
+
+        Raises:
+            DatasetError: If no `primary_key` is specified.
+        """
+        if self.primary_key is None or len(self.primary_key) == 0:
+            if self.write_mode == "upsert":
+                raise DatasetError(
+                    f"`primary_key` must be provided for "
+                    f"`write_mode` {self.write_mode}"
+                )
+
+    def full_table_location(self) -> str | None:
+        """Returns the full table location.
+
+        Returns:
+            str | None : Table location in the format catalog.database.table or None if database and table aren't defined.
+        """
+        full_table_location = None
+        if self.catalog and self.database and self.table:
+            full_table_location = f"`{self.catalog}`.`{self.database}`.`{self.table}`"
+        elif self.database and self.table:
+            full_table_location = f"`{self.database}`.`{self.table}`"
+        return full_table_location
+
+    def schema(self) -> StructType | None:
+        """Returns the Spark schema of the table if it exists.
+
+        Returns:
+            StructType: The schema of the table.
+        """
+        schema = None
+        try:
+            if self.json_schema is not None:
+                schema = StructType.fromJson(self.json_schema)
+        except (KeyError, ValueError) as exc:
+            raise DatasetError(exc) from exc
+        return schema
+
+    def exists(self) -> bool:
+        """Checks to see if the table exists.
+
+        Returns:
+            bool: Boolean of whether the table exists in the Spark session.
+        """
+        if self.catalog:
+            try:
+                _get_spark().sql(f"USE CATALOG `{self.catalog}`")
+            except (ParseException, AnalysisException) as exc:
+                logger.warning(
+                    "catalog %s not found or unity not enabled. Error message: %s",
+                    self.catalog,
+                    exc,
+                )
+        try:
+            return (
+                _get_spark()
+                .sql(f"SHOW TABLES IN `{self.database}`")
+                .filter(f"tableName = '{self.table}'")
+                .count()
+                > 0
+            )
+        except (ParseException, AnalysisException) as exc:
+            logger.warning("error occurred while trying to find table: %s", exc)
+            return False
+
+
+class BaseTableDataset(AbstractVersionedDataset):
+    """``BaseTableDataset`` loads and saves data into managed delta tables or external tables on Databricks.
+    Load and save can be in Spark or Pandas dataframes, specified in dataframe_type.
+
+    This dataset is not meant to be used directly. It is a base class for ``ManagedTableDataset`` and ``ExternalTableDataset``.
+    """
+
+    # datasets that inherit from this class cannot be used with ``ParallelRunner``,
+    # therefore it has the attribute ``_SINGLE_PROCESS = True``
+    # for parallelism within a Spark pipeline please consider
+    # using ``ThreadRunner`` instead.
+ _SINGLE_PROCESS = True + + def __init__( # noqa: PLR0913 + self, + *, + table: str, + catalog: str | None = None, + database: str = "default", + format: str = "delta", + write_mode: str | None = None, + location: str | None = None, + dataframe_type: str = "spark", + primary_key: str | list[str] | None = None, + version: Version | None = None, + # the following parameters are used by project hooks + # to create or update table properties + schema: dict[str, Any] | None = None, + partition_columns: list[str] | None = None, + owner_group: str | None = None, + metadata: dict[str, Any] | None = None, + ) -> None: + """Creates a new instance of ``BaseTableDataset``. + + Args: + table: The name of the table. + catalog: The name of the catalog in Unity. + Defaults to None. + database: The name of the database. + (also referred to as schema). Defaults to "default". + format: The format of the table. + Applicable only for external tables. + Defaults to "delta". + write_mode: The mode to write the data into the table. If not + present, the data set is read-only. + Options are:["overwrite", "append", "upsert"]. + "upsert" mode requires primary_key field to be populated. + Defaults to None. + location: The location of the table. + Applicable only for external tables. + Should be a valid path in an external location that has already been created. + Defaults to None. + dataframe_type: "pandas" or "spark" dataframe. + Defaults to "spark". + primary_key: The primary key of the table. + Can be in the form of a list. Defaults to None. + version: kedro.io.core.Version instance to load the data. + Defaults to None. + schema: The schema of the table in JSON form. + Dataframes will be truncated to match the schema if provided. + Used by the hooks to create the table if the schema is provided. + Defaults to None. + partition_columns: The columns to use for partitioning the table. + Used by the hooks. Defaults to None. + owner_group: If table access control is enabled in your workspace, + specifying owner_group will transfer ownership of the table and database to + this owner. All databases should have the same owner_group. Defaults to None. + metadata: Any arbitrary metadata. + This is ignored by Kedro, but may be consumed by users or external plugins. + Raises: + DatasetError: Invalid configuration supplied (through ``BaseTable`` validation). + """ + self._table = self._create_table( + table=table, + catalog=catalog, + database=database, + format=format, + write_mode=write_mode, + location=location, + dataframe_type=dataframe_type, + primary_key=primary_key, + json_schema=schema, + partition_columns=partition_columns, + owner_group=owner_group, + ) + + self.metadata = metadata + self._version = version + + super().__init__( + filepath=None, # type: ignore[arg-type] + version=version, + exists_function=self._exists, # type: ignore[arg-type] + ) + + def _create_table( # noqa: PLR0913 + self, + table: str, + catalog: str | None, + database: str, + format: str, + write_mode: str | None, + location: str | None, + dataframe_type: str, + primary_key: str | list[str] | None, + json_schema: dict[str, Any] | None, + partition_columns: list[str] | None, + owner_group: str | None, + ) -> BaseTable: + """Creates a ``BaseTable`` instance with the provided attributes. + + Args: + table: The name of the table. + catalog: The catalog of the table. + database: The database of the table. + format: The format of the table. + location: The location of the table. + write_mode: The write mode for the table. 
+            dataframe_type: The type of dataframe.
+            primary_key: The primary key of the table.
+            json_schema: The JSON schema of the table.
+            partition_columns: The partition columns of the table.
+            owner_group: The owner group of the table.
+
+        Returns:
+            ``BaseTable``: The new ``BaseTable`` instance.
+        """
+        return BaseTable(
+            table=table,
+            catalog=catalog,
+            database=database,
+            format=format,
+            write_mode=write_mode,
+            location=location,
+            dataframe_type=dataframe_type,
+            json_schema=json_schema,
+            partition_columns=partition_columns,
+            owner_group=owner_group,
+            primary_key=primary_key,
+        )
+
+    def _load(self) -> DataFrame | pd.DataFrame:
+        """Loads the version of data in the format defined in the init
+        (spark|pandas dataframe).
+
+        Raises:
+            VersionNotFoundError: If the version defined in
+                the init doesn't exist.
+
+        Returns:
+            Union[DataFrame, pd.DataFrame]: Returns a dataframe
+                in the format defined in the init.
+        """
+        if self._version and self._version.load >= 0:
+            try:
+                data = (
+                    _get_spark()
+                    .read.format("delta")
+                    .option("versionAsOf", self._version.load)
+                    .table(self._table.full_table_location())
+                )
+            except Exception as exc:
+                raise VersionNotFoundError(self._version.load) from exc
+        else:
+            data = _get_spark().table(self._table.full_table_location())
+        if self._table.dataframe_type == "pandas":
+            data = data.toPandas()
+        return data
+
+    def _save(self, data: DataFrame | pd.DataFrame) -> None:
+        """Saves the data based on the write_mode and dataframe_type in the init.
+        If dataframe_type is pandas, a Spark dataframe is created first.
+        If schema is provided, data is matched to schema before saving
+        (columns will be sorted and truncated).
+
+        Args:
+            data (DataFrame | pd.DataFrame): Spark or pandas dataframe to save to the table location.
+        """
+        if self._table.write_mode is None:
+            raise DatasetError(
+                "'save' cannot be used in read-only mode. "
+                f"Change 'write_mode' value to {', '.join(self._table._VALID_WRITE_MODES)}"
+            )
+        # filter columns specified in schema and match their ordering
+        schema = self._table.schema()
+        if schema:
+            cols = schema.fieldNames()
+            if self._table.dataframe_type == "pandas":
+                data = _get_spark().createDataFrame(
+                    data.loc[:, cols], schema=self._table.schema()
+                )
+            else:
+                data = data.select(*cols)
+        elif self._table.dataframe_type == "pandas":
+            data = _get_spark().createDataFrame(data)
+
+        method = getattr(self, f"_save_{self._table.write_mode}", None)
+        if method:
+            method(data)
+
+    def _save_append(self, data: DataFrame) -> None:
+        """Saves the data to the table by appending it
+        to the location defined in the init.
+
+        Args:
+            data (DataFrame): The Spark dataframe to append to the table.
+        """
+        writer = data.write.format(self._table.format).mode("append")
+
+        if self._table.partition_columns:
+            writer.partitionBy(
+                *self._table.partition_columns
+                if isinstance(self._table.partition_columns, list)
+                else [self._table.partition_columns]
+            )
+
+        if self._table.location:
+            writer.option("path", self._table.location)
+
+        writer.saveAsTable(self._table.full_table_location() or "")
+
+    def _save_overwrite(self, data: DataFrame) -> None:
+        """Overwrites the data in the table with the data provided.
+
+        Args:
+            data (DataFrame): The Spark dataframe to overwrite the table with.
+ """ + writer = ( + data.write.format(self._table.format) + .mode("overwrite") + .option("overwriteSchema", "true") + ) + + if self._table.partition_columns: + writer.partitionBy( + *self._table.partition_columns + if isinstance(self._table.partition_columns, list) + else [self._table.partition_columns] + ) + + if self._table.location: + writer.option("path", self._table.location) + + writer.saveAsTable(self._table.full_table_location() or "") + + def _save_upsert(self, update_data: DataFrame) -> None: + """Upserts the data by joining on primary_key columns or column. + If table doesn't exist at save, the data is inserted to a new table. + + Args: + update_data (DataFrame): The Spark dataframe to upsert. + """ + if self._exists(): + base_data = _get_spark().table(self._table.full_table_location()) + base_columns = base_data.columns + update_columns = update_data.columns + + if set(update_columns) != set(base_columns): + raise DatasetError( + f"Upsert requires tables to have identical columns. " + f"Delta table {self._table.full_table_location()} " + f"has columns: {base_columns}, whereas " + f"dataframe has columns {update_columns}" + ) + + where_expr = "" + if isinstance(self._table.primary_key, str): + where_expr = ( + f"base.{self._table.primary_key}=update.{self._table.primary_key}" + ) + elif isinstance(self._table.primary_key, list): + where_expr = " AND ".join( + f"base.{col}=update.{col}" for col in self._table.primary_key + ) + + update_data.createOrReplaceTempView("update") + _get_spark().conf.set("fullTableAddress", self._table.full_table_location()) + _get_spark().conf.set("whereExpr", where_expr) + upsert_sql = """MERGE INTO ${fullTableAddress} base USING update ON ${whereExpr} + WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *""" + _get_spark().sql(upsert_sql) + else: + self._save_append(update_data) + + def _describe(self) -> dict[str, str | list | None]: + """Returns a description of the instance of the dataset. + + Returns: + Dict[str, str]: Dict with the details of the dataset. + """ + return { + "catalog": self._table.catalog, + "database": self._table.database, + "table": self._table.table, + "format": self._table.format, + "location": self._table.location, + "write_mode": self._table.write_mode, + "dataframe_type": self._table.dataframe_type, + "primary_key": self._table.primary_key, + "version": str(self._version), + "owner_group": self._table.owner_group, + "partition_columns": self._table.partition_columns, + } + + def _exists(self) -> bool: + """Checks to see if the table exists. + + Returns: + bool: Boolean of whether the table defined + in the dataset instance exists in the Spark session. 
+ """ + return self._table.exists() diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index ecca89f80..cd759aae8 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -4,158 +4,29 @@ from __future__ import annotations import logging -import re -from dataclasses import dataclass -from typing import Any +from dataclasses import dataclass, field +from typing import Any, ClassVar import pandas as pd -from kedro.io.core import ( - AbstractVersionedDataset, - DatasetError, - Version, - VersionNotFoundError, -) -from pyspark.sql import DataFrame -from pyspark.sql.types import StructType -from pyspark.sql.utils import AnalysisException, ParseException +from kedro.io.core import Version -from kedro_datasets.spark.spark_dataset import _get_spark +from kedro_datasets.databricks._base_table_dataset import BaseTable, BaseTableDataset logger = logging.getLogger(__name__) pd.DataFrame.iteritems = pd.DataFrame.items @dataclass(frozen=True) -class ManagedTable: - """Stores the definition of a managed table""" +class ManagedTable(BaseTable): + """Stores the definition of a managed table.""" - # regex for tables, catalogs and schemas - _NAMING_REGEX = r"\b[0-9a-zA-Z_-]{1,}\b" - _VALID_WRITE_MODES = ["overwrite", "upsert", "append"] - _VALID_DATAFRAME_TYPES = ["spark", "pandas"] - database: str - catalog: str | None - table: str - write_mode: str | None - dataframe_type: str - primary_key: str | list[str] | None - owner_group: str | None - partition_columns: str | list[str] | None - json_schema: dict[str, Any] | None = None + _VALID_FORMATS: ClassVar[list[str]] = field(default=["delta"]) - def __post_init__(self): - """Run validation methods if declared. - The validation method can be a simple check - that raises DatasetError. - - The validation is performed by calling a function with the signature - `validate_(self, value) -> raises DatasetError`. - """ - for name in self.__dataclass_fields__.keys(): - method = getattr(self, f"_validate_{name}", None) - if method: - method() - - def _validate_table(self): - """Validates table name - - Raises: - DatasetError: If the table name does not conform to naming constraints. - """ - if not re.fullmatch(self._NAMING_REGEX, self.table): - raise DatasetError("table does not conform to naming") - - def _validate_database(self): - """Validates database name - - Raises: - DatasetError: If the dataset name does not conform to naming constraints. - """ - if not re.fullmatch(self._NAMING_REGEX, self.database): - raise DatasetError("database does not conform to naming") - - def _validate_catalog(self): - """Validates catalog name - - Raises: - DatasetError: If the catalog name does not conform to naming constraints. - """ - if self.catalog: - if not re.fullmatch(self._NAMING_REGEX, self.catalog): - raise DatasetError("catalog does not conform to naming") - - def _validate_write_mode(self): - """Validates the write mode - - Raises: - DatasetError: If an invalid `write_mode` is passed. - """ - if ( - self.write_mode is not None - and self.write_mode not in self._VALID_WRITE_MODES - ): - valid_modes = ", ".join(self._VALID_WRITE_MODES) - raise DatasetError( - f"Invalid `write_mode` provided: {self.write_mode}. 
" - f"`write_mode` must be one of: {valid_modes}" - ) - - def _validate_dataframe_type(self): - """Validates the dataframe type - - Raises: - DatasetError: If an invalid `dataframe_type` is passed - """ - if self.dataframe_type not in self._VALID_DATAFRAME_TYPES: - valid_types = ", ".join(self._VALID_DATAFRAME_TYPES) - raise DatasetError(f"`dataframe_type` must be one of {valid_types}") - - def _validate_primary_key(self): - """Validates the primary key of the table - - Raises: - DatasetError: If no `primary_key` is specified. - """ - if self.primary_key is None or len(self.primary_key) == 0: - if self.write_mode == "upsert": - raise DatasetError( - f"`primary_key` must be provided for" - f"`write_mode` {self.write_mode}" - ) - - def full_table_location(self) -> str | None: - """Returns the full table location - - Returns: - str | None : table location in the format catalog.database.table or None if database and table aren't defined - """ - full_table_location = None - if self.catalog and self.database and self.table: - full_table_location = f"`{self.catalog}`.`{self.database}`.`{self.table}`" - elif self.database and self.table: - full_table_location = f"`{self.database}`.`{self.table}`" - return full_table_location - - def schema(self) -> StructType | None: - """Returns the Spark schema of the table if it exists - - Returns: - StructType: - """ - schema = None - try: - if self.json_schema is not None: - schema = StructType.fromJson(self.json_schema) - except (KeyError, ValueError) as exc: - raise DatasetError(exc) from exc - return schema - - -class ManagedTableDataset(AbstractVersionedDataset): - """``ManagedTableDataset`` loads and saves data into managed delta tables on Databricks. +class ManagedTableDataset(BaseTableDataset): + """``ManagedTableDataset`` loads and saves data into managed delta tables in Databricks. Load and save can be in Spark or Pandas dataframes, specified in dataframe_type. - When saving data, you can specify one of three modes: overwrite(default), append, + When saving data, you can specify one of three modes: overwrite, append, or upsert. Upsert requires you to specify the primary_column parameter which will be used as part of the join condition. This dataset works best with the databricks kedro starter. That starter comes with hooks that allow this @@ -210,12 +81,6 @@ class ManagedTableDataset(AbstractVersionedDataset): >>> assert Row(name="Bob", age=12) in reloaded.take(4) """ - # this dataset cannot be used with ``ParallelRunner``, - # therefore it has the attribute ``_SINGLE_PROCESS = True`` - # for parallelism within a Spark pipeline please consider - # using ``ThreadRunner`` instead - _SINGLE_PROCESS = True - def __init__( # noqa: PLR0913 self, *, @@ -236,10 +101,10 @@ def __init__( # noqa: PLR0913 """Creates a new instance of ``ManagedTableDataset``. Args: - table: the name of the table - catalog: the name of the catalog in Unity. + table: The name of the table. + catalog: The name of the catalog in Unity. Defaults to None. - database: the name of the database. + database: The name of the database. (also referred to as schema). Defaults to "default". write_mode: the mode to write the data into the table. If not present, the dataset is read-only. @@ -248,213 +113,91 @@ def __init__( # noqa: PLR0913 Defaults to None. dataframe_type: "pandas" or "spark" dataframe. Defaults to "spark". - primary_key: the primary key of the table. + primary_key: The primary key of the table. Can be in the form of a list. Defaults to None. 
version: kedro.io.core.Version instance to load the data. Defaults to None. - schema: the schema of the table in JSON form. + schema: The schema of the table in JSON form. Dataframes will be truncated to match the schema if provided. - Used by the hooks to create the table if the schema is provided + Used by the hooks to create the table if the schema is provided. Defaults to None. - partition_columns: the columns to use for partitioning the table. + partition_columns: The columns to use for partitioning the table. Used by the hooks. Defaults to None. - owner_group: if table access control is enabled in your workspace, + owner_group: If table access control is enabled in your workspace, specifying owner_group will transfer ownership of the table and database to this owner. All databases should have the same owner_group. Defaults to None. metadata: Any arbitrary metadata. This is ignored by Kedro, but may be consumed by users or external plugins. Raises: - DatasetError: Invalid configuration supplied (through ManagedTable validation) + DatasetError: Invalid configuration supplied (through ``ManagedTable`` validation). """ - - self._table = ManagedTable( + super().__init__( database=database, catalog=catalog, table=table, write_mode=write_mode, dataframe_type=dataframe_type, + version=version, + schema=schema, + partition_columns=partition_columns, + metadata=metadata, primary_key=primary_key, owner_group=owner_group, - partition_columns=partition_columns, - json_schema=schema, - ) - - self._version = version - self.metadata = metadata - - super().__init__( - filepath=None, # type: ignore[arg-type] - version=version, - exists_function=self._exists, # type: ignore[arg-type] ) - def load(self) -> DataFrame | pd.DataFrame: - """Loads the version of data in the format defined in the init - (spark|pandas dataframe) + def _create_table( # noqa: PLR0913 + self, + table: str, + catalog: str | None, + database: str, + format: str, + write_mode: str | None, + location: str | None, + dataframe_type: str, + primary_key: str | list[str] | None, + json_schema: dict[str, Any] | None, + partition_columns: list[str] | None, + owner_group: str | None, + ) -> ManagedTable: + """Creates a new ``ManagedTable`` instance with the provided attributes. - Raises: - VersionNotFoundError: if the version defined in - the init doesn't exist + Args: + table: The name of the table. + catalog: The catalog of the table. + database: The database of the table. + format: The format of the table. + write_mode: The write mode for the table. + dataframe_type: The type of dataframe. + primary_key: The primary key of the table. + json_schema: The JSON schema of the table. + partition_columns: The partition columns of the table. + owner_group: The owner group of the table. Returns: - DataFrame | pd.DataFrame: Returns a dataframe - in the format defined in the init + ``ManagedTable``: The new ``ManagedTable`` instance. 
""" - if self._version and self._version.load >= 0: - try: - data = ( - _get_spark() - .read.format("delta") - .option("versionAsOf", self._version.load) - .table(self._table.full_table_location()) - ) - except Exception as exc: - raise VersionNotFoundError(self._version.load) from exc - else: - data = _get_spark().table(self._table.full_table_location()) - if self._table.dataframe_type == "pandas": - data = data.toPandas() - return data - - def _save_append(self, data: DataFrame) -> None: - """Saves the data to the table by appending it - to the location defined in the init - - Args: - data (DataFrame): the Spark dataframe to append to the table - """ - data.write.format("delta").mode("append").saveAsTable( - self._table.full_table_location() or "" + return ManagedTable( + table=table, + catalog=catalog, + database=database, + write_mode=write_mode, + location=location, + dataframe_type=dataframe_type, + json_schema=json_schema, + partition_columns=partition_columns, + owner_group=owner_group, + primary_key=primary_key, + format=format, ) - def _save_overwrite(self, data: DataFrame) -> None: - """Overwrites the data in the table with the data provided. - (this is the default save mode) - - Args: - data (DataFrame): the Spark dataframe to overwrite the table with. - """ - delta_table = data.write.format("delta") - if self._table.write_mode == "overwrite": - delta_table = delta_table.mode("overwrite").option( - "overwriteSchema", "true" - ) - delta_table.saveAsTable(self._table.full_table_location() or "") - - def _save_upsert(self, update_data: DataFrame) -> None: - """Upserts the data by joining on primary_key columns or column. - If table doesn't exist at save, the data is inserted to a new table. - - Args: - update_data (DataFrame): the Spark dataframe to upsert - """ - if self._exists(): - base_data = _get_spark().table(self._table.full_table_location()) - base_columns = base_data.columns - update_columns = update_data.columns - - if set(update_columns) != set(base_columns): - raise DatasetError( - f"Upsert requires tables to have identical columns. " - f"Delta table {self._table.full_table_location()} " - f"has columns: {base_columns}, whereas " - f"dataframe has columns {update_columns}" - ) - - where_expr = "" - if isinstance(self._table.primary_key, str): - where_expr = ( - f"base.{self._table.primary_key}=update.{self._table.primary_key}" - ) - elif isinstance(self._table.primary_key, list): - where_expr = " AND ".join( - f"base.{col}=update.{col}" for col in self._table.primary_key - ) - - update_data.createOrReplaceTempView("update") - _get_spark().conf.set("fullTableAddress", self._table.full_table_location()) - _get_spark().conf.set("whereExpr", where_expr) - upsert_sql = """MERGE INTO ${fullTableAddress} base USING update ON ${whereExpr} - WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *""" - _get_spark().sql(upsert_sql) - else: - self._save_append(update_data) - - def save(self, data: DataFrame | pd.DataFrame) -> None: - """Saves the data based on the write_mode and dataframe_type in the init. - If write_mode is pandas, Spark dataframe is created first. - If schema is provided, data is matched to schema before saving - (columns will be sorted and truncated). - - Args: - data (Any): Spark or pandas dataframe to save to the table location - """ - if self._table.write_mode is None: - raise DatasetError( - "'save' can not be used in read-only mode. " - "Change 'write_mode' value to `overwrite`, `upsert` or `append`." 
- ) - # filter columns specified in schema and match their ordering - schema = self._table.schema() - if schema: - cols = schema.fieldNames() - if self._table.dataframe_type == "pandas": - data = _get_spark().createDataFrame( - data.loc[:, cols], schema=self._table.schema() - ) - else: - data = data.select(*cols) - elif self._table.dataframe_type == "pandas": - data = _get_spark().createDataFrame(data) - if self._table.write_mode == "overwrite": - self._save_overwrite(data) - elif self._table.write_mode == "upsert": - self._save_upsert(data) - elif self._table.write_mode == "append": - self._save_append(data) - def _describe(self) -> dict[str, str | list | None]: - """Returns a description of the instance of ManagedTableDataset + """Returns a description of the instance of the dataset. Returns: - Dict[str, str]: Dict with the details of the dataset + Dict[str, str]: Dict with the details of the dataset. """ - return { - "catalog": self._table.catalog, - "database": self._table.database, - "table": self._table.table, - "write_mode": self._table.write_mode, - "dataframe_type": self._table.dataframe_type, - "primary_key": self._table.primary_key, - "version": str(self._version), - "owner_group": self._table.owner_group, - "partition_columns": self._table.partition_columns, - } - - def _exists(self) -> bool: - """Checks to see if the table exists + description = super()._describe() + del description["format"] + del description["location"] - Returns: - bool: boolean of whether the table defined - in the dataset instance exists in the Spark session - """ - if self._table.catalog: - try: - _get_spark().sql(f"USE CATALOG `{self._table.catalog}`") - except (ParseException, AnalysisException) as exc: - logger.warning( - "catalog %s not found or unity not enabled. 
Error message: %s", - self._table.catalog, - exc, - ) - try: - return ( - _get_spark() - .sql(f"SHOW TABLES IN `{self._table.database}`") - .filter(f"tableName = '{self._table.table}'") - .count() - > 0 - ) - except (ParseException, AnalysisException) as exc: - logger.warning("error occured while trying to find table: %s", exc) - return False + return description diff --git a/kedro-datasets/tests/databricks/test_base_table_dataset.py b/kedro-datasets/tests/databricks/test_base_table_dataset.py new file mode 100644 index 000000000..5cc88e8df --- /dev/null +++ b/kedro-datasets/tests/databricks/test_base_table_dataset.py @@ -0,0 +1,576 @@ +import os + +import pandas as pd +import pytest +from kedro.io.core import DatasetError, Version, VersionNotFoundError +from pyspark.sql import DataFrame, SparkSession +from pyspark.sql.types import IntegerType, StringType, StructField, StructType + +from kedro_datasets.databricks._base_table_dataset import BaseTableDataset + + +@pytest.fixture +def sample_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + + data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def upsert_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + + data = [("Alex", 32), ("Evan", 23)] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def mismatched_upsert_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", IntegerType(), True), + ] + ) + + data = [("Alex", 32, 174), ("Evan", 23, 166)] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def subset_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", IntegerType(), True), + ] + ) + + data = [("Alex", 32, 174), ("Evan", 23, 166)] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def subset_pandas_df(): + return pd.DataFrame( + {"name": ["Alex", "Evan"], "age": [32, 23], "height": [174, 166]} + ) + + +@pytest.fixture +def subset_expected_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + + data = [("Alex", 32), ("Evan", 23)] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def sample_pandas_df(): + return pd.DataFrame( + {"name": ["Alex", "Bob", "Clarke", "Dave"], "age": [31, 12, 65, 29]} + ) + + +@pytest.fixture +def append_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + + data = [("Evan", 23), ("Frank", 13)] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def expected_append_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + + data = [ + ("Alex", 31), + ("Bob", 12), + ("Clarke", 65), + ("Dave", 29), + ("Evan", 23), + ("Frank", 13), + ] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def 
expected_upsert_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + + data = [ + ("Alex", 32), + ("Bob", 12), + ("Clarke", 65), + ("Dave", 29), + ("Evan", 23), + ] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def expected_upsert_multiple_primary_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + + data = [ + ("Alex", 31), + ("Alex", 32), + ("Bob", 12), + ("Clarke", 65), + ("Dave", 29), + ("Evan", 23), + ] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def external_location(): + return os.environ.get("DATABRICKS_EXTERNAL_LOCATION") + + +class TestBaseTableDataset: + def test_full_table(self): + unity_ds = BaseTableDataset(catalog="test", database="test", table="test") + assert unity_ds._table.full_table_location() == "`test`.`test`.`test`" + + unity_ds = BaseTableDataset(catalog="test-test", database="test", table="test") + assert unity_ds._table.full_table_location() == "`test-test`.`test`.`test`" + + unity_ds = BaseTableDataset(database="test", table="test") + assert unity_ds._table.full_table_location() == "`test`.`test`" + + unity_ds = BaseTableDataset(table="test") + assert unity_ds._table.full_table_location() == "`default`.`test`" + + with pytest.raises(TypeError): + BaseTableDataset() + + def test_describe(self): + unity_ds = BaseTableDataset(table="test") + assert unity_ds._describe() == { + "catalog": None, + "database": "default", + "table": "test", + "format": "delta", + "location": None, + "write_mode": None, + "dataframe_type": "spark", + "primary_key": None, + "version": "None", + "owner_group": None, + "partition_columns": None, + } + + def test_invalid_write_mode(self): + with pytest.raises(DatasetError): + BaseTableDataset(table="test", write_mode="invalid") + + def test_dataframe_type(self): + with pytest.raises(DatasetError): + BaseTableDataset(table="test", dataframe_type="invalid") + + def test_missing_primary_key_upsert(self): + with pytest.raises(DatasetError): + BaseTableDataset(table="test", write_mode="upsert") + + def test_invalid_table_name(self): + with pytest.raises(DatasetError): + BaseTableDataset(table="invalid!") + + def test_invalid_database(self): + with pytest.raises(DatasetError): + BaseTableDataset(table="test", database="invalid!") + + def test_invalid_catalog(self): + with pytest.raises(DatasetError): + BaseTableDataset(table="test", catalog="invalid!") + + def test_invalid_format(self): + with pytest.raises(DatasetError): + BaseTableDataset(table="test", format="invalid") + + def test_schema(self): + unity_ds = BaseTableDataset( + table="test", + schema={ + "fields": [ + { + "metadata": {}, + "name": "name", + "nullable": True, + "type": "string", + }, + { + "metadata": {}, + "name": "age", + "nullable": True, + "type": "integer", + }, + ], + "type": "struct", + }, + ) + expected_schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + assert unity_ds._table.schema() == expected_schema + + def test_invalid_schema(self): + with pytest.raises(DatasetError): + BaseTableDataset( + table="test", + schema={ + "fields": [ + { + "invalid": "schema", + } + ], + "type": "struct", + }, + )._table.schema() + + def test_catalog_exists(self): + unity_ds = BaseTableDataset( + catalog="test", database="invalid", 
table="test_not_there" + ) + assert not unity_ds._exists() + + def test_table_does_not_exist(self): + unity_ds = BaseTableDataset(database="invalid", table="test_not_there") + assert not unity_ds._exists() + + def test_save_default(self, sample_spark_df: DataFrame): + unity_ds = BaseTableDataset(database="test", table="test_save") + with pytest.raises(DatasetError): + unity_ds.save(sample_spark_df) + + def test_save_schema_spark( + self, subset_spark_df: DataFrame, subset_expected_df: DataFrame + ): + unity_ds = BaseTableDataset( + database="test", + table="test_save_spark_schema", + schema={ + "fields": [ + { + "metadata": {}, + "name": "name", + "nullable": True, + "type": "string", + }, + { + "metadata": {}, + "name": "age", + "nullable": True, + "type": "integer", + }, + ], + "type": "struct", + }, + write_mode="overwrite", + ) + unity_ds.save(subset_spark_df) + saved_table = unity_ds.load() + assert subset_expected_df.exceptAll(saved_table).count() == 0 + + def test_save_schema_pandas( + self, subset_pandas_df: pd.DataFrame, subset_expected_df: DataFrame + ): + unity_ds = BaseTableDataset( + database="test", + table="test_save_pd_schema", + schema={ + "fields": [ + { + "metadata": {}, + "name": "name", + "nullable": True, + "type": "string", + }, + { + "metadata": {}, + "name": "age", + "nullable": True, + "type": "integer", + }, + ], + "type": "struct", + }, + write_mode="overwrite", + dataframe_type="pandas", + ) + unity_ds.save(subset_pandas_df) + saved_ds = BaseTableDataset( + database="test", + table="test_save_pd_schema", + ) + saved_table = saved_ds.load() + assert subset_expected_df.exceptAll(saved_table).count() == 0 + + def test_save_overwrite( + self, sample_spark_df: DataFrame, append_spark_df: DataFrame + ): + unity_ds = BaseTableDataset( + database="test", table="test_save", write_mode="overwrite" + ) + unity_ds.save(sample_spark_df) + unity_ds.save(append_spark_df) + + overwritten_table = unity_ds.load() + + assert append_spark_df.exceptAll(overwritten_table).count() == 0 + + def test_save_overwrite_partitioned( + self, sample_spark_df: DataFrame, append_spark_df: DataFrame + ): + unity_ds = BaseTableDataset( + database="test", + table="test_save_partitioned", + write_mode="overwrite", + partition_columns=["name"], + ) + unity_ds.save(sample_spark_df) + unity_ds.save(append_spark_df) + + overwritten_table = unity_ds.load() + + assert append_spark_df.exceptAll(overwritten_table).count() == 0 + + def test_save_overwrite_external( + self, + sample_spark_df: DataFrame, + append_spark_df: DataFrame, + external_location: str, + ): + unity_ds = BaseTableDataset( + database="test", + table="test_save_external", + write_mode="overwrite", + location=f"{external_location}/test_save_external", + ) + unity_ds.save(sample_spark_df) + unity_ds.save(append_spark_df) + + overwritten_table = unity_ds.load() + + assert append_spark_df.exceptAll(overwritten_table).count() == 0 + + def test_save_append( + self, + sample_spark_df: DataFrame, + append_spark_df: DataFrame, + expected_append_spark_df: DataFrame, + ): + unity_ds = BaseTableDataset( + database="test", table="test_save_append", write_mode="append" + ) + unity_ds.save(sample_spark_df) + unity_ds.save(append_spark_df) + + appended_table = unity_ds.load() + + assert expected_append_spark_df.exceptAll(appended_table).count() == 0 + + def test_save_append_partitioned( + self, + sample_spark_df: DataFrame, + append_spark_df: DataFrame, + expected_append_spark_df: DataFrame, + ): + unity_ds = BaseTableDataset( + database="test", + 
table="test_save_append_partitioned", + write_mode="append", + partition_columns=["name"], + ) + unity_ds.save(sample_spark_df) + unity_ds.save(append_spark_df) + + appended_table = unity_ds.load() + + assert expected_append_spark_df.exceptAll(appended_table).count() == 0 + + def test_save_append_external( + self, + sample_spark_df: DataFrame, + append_spark_df: DataFrame, + expected_append_spark_df: DataFrame, + external_location: str, + ): + unity_ds = BaseTableDataset( + database="test", + table="test_save_append_external", + write_mode="append", + location=f"{external_location}/test_save_append_external", + ) + unity_ds.save(sample_spark_df) + unity_ds.save(append_spark_df) + + appended_table = unity_ds.load() + + assert expected_append_spark_df.exceptAll(appended_table).count() == 0 + + def test_save_upsert( + self, + sample_spark_df: DataFrame, + upsert_spark_df: DataFrame, + expected_upsert_spark_df: DataFrame, + ): + unity_ds = BaseTableDataset( + database="test", + table="test_save_upsert", + write_mode="upsert", + primary_key="name", + ) + unity_ds.save(sample_spark_df) + unity_ds.save(upsert_spark_df) + + upserted_table = unity_ds.load() + + assert expected_upsert_spark_df.exceptAll(upserted_table).count() == 0 + + def test_save_upsert_multiple_primary( + self, + sample_spark_df: DataFrame, + upsert_spark_df: DataFrame, + expected_upsert_multiple_primary_spark_df: DataFrame, + ): + unity_ds = BaseTableDataset( + database="test", + table="test_save_upsert_multiple", + write_mode="upsert", + primary_key=["name", "age"], + ) + unity_ds.save(sample_spark_df) + unity_ds.save(upsert_spark_df) + + upserted_table = unity_ds.load() + + assert ( + expected_upsert_multiple_primary_spark_df.exceptAll(upserted_table).count() + == 0 + ) + + def test_save_upsert_mismatched_columns( + self, + sample_spark_df: DataFrame, + mismatched_upsert_spark_df: DataFrame, + ): + unity_ds = BaseTableDataset( + database="test", + table="test_save_upsert_mismatch", + write_mode="upsert", + primary_key="name", + ) + unity_ds.save(sample_spark_df) + with pytest.raises(DatasetError): + unity_ds.save(mismatched_upsert_spark_df) + + def test_load_spark(self, sample_spark_df: DataFrame): + unity_ds = BaseTableDataset( + database="test", table="test_load_spark", write_mode="overwrite" + ) + unity_ds.save(sample_spark_df) + + delta_ds = BaseTableDataset(database="test", table="test_load_spark") + delta_table = delta_ds.load() + + assert ( + isinstance(delta_table, DataFrame) + and delta_table.exceptAll(sample_spark_df).count() == 0 + ) + + def test_load_spark_no_version(self, sample_spark_df: DataFrame): + unity_ds = BaseTableDataset( + database="test", table="test_load_spark", write_mode="overwrite" + ) + unity_ds.save(sample_spark_df) + + delta_ds = BaseTableDataset( + database="test", table="test_load_spark", version=Version(2, None) + ) + with pytest.raises(VersionNotFoundError): + _ = delta_ds.load() + + def test_load_version(self, sample_spark_df: DataFrame, append_spark_df: DataFrame): + unity_ds = BaseTableDataset( + database="test", table="test_load_version", write_mode="append" + ) + unity_ds.save(sample_spark_df) + unity_ds.save(append_spark_df) + + loaded_ds = BaseTableDataset( + database="test", table="test_load_version", version=Version(0, None) + ) + loaded_df = loaded_ds.load() + + assert loaded_df.exceptAll(sample_spark_df).count() == 0 + + def test_load_pandas(self, sample_pandas_df: pd.DataFrame): + unity_ds = BaseTableDataset( + database="test", + table="test_load_pandas", + 
dataframe_type="pandas", + write_mode="overwrite", + ) + unity_ds.save(sample_pandas_df) + + pandas_ds = BaseTableDataset( + database="test", table="test_load_pandas", dataframe_type="pandas" + ) + pandas_df = pandas_ds.load().sort_values("name", ignore_index=True) + + assert isinstance(pandas_df, pd.DataFrame) and pandas_df.equals( + sample_pandas_df + ) diff --git a/kedro-datasets/tests/databricks/test_managed_table_dataset.py b/kedro-datasets/tests/databricks/test_managed_table_dataset.py index 03a85d27e..c3cc623f4 100644 --- a/kedro-datasets/tests/databricks/test_managed_table_dataset.py +++ b/kedro-datasets/tests/databricks/test_managed_table_dataset.py @@ -1,7 +1,6 @@ import pandas as pd import pytest -from kedro.io.core import DatasetError, Version, VersionNotFoundError -from pyspark.sql import DataFrame, SparkSession +from pyspark.sql import SparkSession from pyspark.sql.types import IntegerType, StringType, StructField, StructType from kedro_datasets.databricks import ManagedTableDataset @@ -170,24 +169,6 @@ def expected_upsert_multiple_primary_spark_df(spark_session: SparkSession): class TestManagedTableDataset: - def test_full_table(self): - unity_ds = ManagedTableDataset(catalog="test", database="test", table="test") - assert unity_ds._table.full_table_location() == "`test`.`test`.`test`" - - unity_ds = ManagedTableDataset( - catalog="test-test", database="test", table="test" - ) - assert unity_ds._table.full_table_location() == "`test-test`.`test`.`test`" - - unity_ds = ManagedTableDataset(database="test", table="test") - assert unity_ds._table.full_table_location() == "`test`.`test`" - - unity_ds = ManagedTableDataset(table="test") - assert unity_ds._table.full_table_location() == "`default`.`test`" - - with pytest.raises(TypeError): - ManagedTableDataset() - def test_describe(self): unity_ds = ManagedTableDataset(table="test") assert unity_ds._describe() == { @@ -201,291 +182,3 @@ def test_describe(self): "owner_group": None, "partition_columns": None, } - - def test_invalid_write_mode(self): - with pytest.raises(DatasetError): - ManagedTableDataset(table="test", write_mode="invalid") - - def test_dataframe_type(self): - with pytest.raises(DatasetError): - ManagedTableDataset(table="test", dataframe_type="invalid") - - def test_missing_primary_key_upsert(self): - with pytest.raises(DatasetError): - ManagedTableDataset(table="test", write_mode="upsert") - - def test_invalid_table_name(self): - with pytest.raises(DatasetError): - ManagedTableDataset(table="invalid!") - - def test_invalid_database(self): - with pytest.raises(DatasetError): - ManagedTableDataset(table="test", database="invalid!") - - def test_invalid_catalog(self): - with pytest.raises(DatasetError): - ManagedTableDataset(table="test", catalog="invalid!") - - def test_schema(self): - unity_ds = ManagedTableDataset( - table="test", - schema={ - "fields": [ - { - "metadata": {}, - "name": "name", - "nullable": True, - "type": "string", - }, - { - "metadata": {}, - "name": "age", - "nullable": True, - "type": "integer", - }, - ], - "type": "struct", - }, - ) - expected_schema = StructType( - [ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True), - ] - ) - assert unity_ds._table.schema() == expected_schema - - def test_invalid_schema(self): - with pytest.raises(DatasetError): - ManagedTableDataset( - table="test", - schema={ - "fields": [ - { - "invalid": "schema", - } - ], - "type": "struct", - }, - )._table.schema() - - def test_catalog_exists(self): - unity_ds = 
ManagedTableDataset( - catalog="test", database="invalid", table="test_not_there" - ) - assert not unity_ds._exists() - - def test_table_does_not_exist(self): - unity_ds = ManagedTableDataset(database="invalid", table="test_not_there") - assert not unity_ds._exists() - - def test_save_default(self, sample_spark_df: DataFrame): - unity_ds = ManagedTableDataset(database="test", table="test_save") - with pytest.raises(DatasetError): - unity_ds.save(sample_spark_df) - - def test_save_schema_spark( - self, subset_spark_df: DataFrame, subset_expected_df: DataFrame - ): - unity_ds = ManagedTableDataset( - database="test", - table="test_save_spark_schema", - schema={ - "fields": [ - { - "metadata": {}, - "name": "name", - "nullable": True, - "type": "string", - }, - { - "metadata": {}, - "name": "age", - "nullable": True, - "type": "integer", - }, - ], - "type": "struct", - }, - write_mode="overwrite", - ) - unity_ds.save(subset_spark_df) - saved_table = unity_ds.load() - assert subset_expected_df.exceptAll(saved_table).count() == 0 - - def test_save_schema_pandas( - self, subset_pandas_df: pd.DataFrame, subset_expected_df: DataFrame - ): - unity_ds = ManagedTableDataset( - database="test", - table="test_save_pd_schema", - schema={ - "fields": [ - { - "metadata": {}, - "name": "name", - "nullable": True, - "type": "string", - }, - { - "metadata": {}, - "name": "age", - "nullable": True, - "type": "integer", - }, - ], - "type": "struct", - }, - write_mode="overwrite", - dataframe_type="pandas", - ) - unity_ds.save(subset_pandas_df) - saved_ds = ManagedTableDataset( - database="test", - table="test_save_pd_schema", - ) - saved_table = saved_ds.load() - assert subset_expected_df.exceptAll(saved_table).count() == 0 - - def test_save_overwrite( - self, sample_spark_df: DataFrame, append_spark_df: DataFrame - ): - unity_ds = ManagedTableDataset( - database="test", table="test_save", write_mode="overwrite" - ) - unity_ds.save(sample_spark_df) - unity_ds.save(append_spark_df) - - overwritten_table = unity_ds.load() - - assert append_spark_df.exceptAll(overwritten_table).count() == 0 - - def test_save_append( - self, - sample_spark_df: DataFrame, - append_spark_df: DataFrame, - expected_append_spark_df: DataFrame, - ): - unity_ds = ManagedTableDataset( - database="test", table="test_save_append", write_mode="append" - ) - unity_ds.save(sample_spark_df) - unity_ds.save(append_spark_df) - - appended_table = unity_ds.load() - - assert expected_append_spark_df.exceptAll(appended_table).count() == 0 - - def test_save_upsert( - self, - sample_spark_df: DataFrame, - upsert_spark_df: DataFrame, - expected_upsert_spark_df: DataFrame, - ): - unity_ds = ManagedTableDataset( - database="test", - table="test_save_upsert", - write_mode="upsert", - primary_key="name", - ) - unity_ds.save(sample_spark_df) - unity_ds.save(upsert_spark_df) - - upserted_table = unity_ds.load() - - assert expected_upsert_spark_df.exceptAll(upserted_table).count() == 0 - - def test_save_upsert_multiple_primary( - self, - sample_spark_df: DataFrame, - upsert_spark_df: DataFrame, - expected_upsert_multiple_primary_spark_df: DataFrame, - ): - unity_ds = ManagedTableDataset( - database="test", - table="test_save_upsert_multiple", - write_mode="upsert", - primary_key=["name", "age"], - ) - unity_ds.save(sample_spark_df) - unity_ds.save(upsert_spark_df) - - upserted_table = unity_ds.load() - - assert ( - expected_upsert_multiple_primary_spark_df.exceptAll(upserted_table).count() - == 0 - ) - - def test_save_upsert_mismatched_columns( - self, - 
sample_spark_df: DataFrame, - mismatched_upsert_spark_df: DataFrame, - ): - unity_ds = ManagedTableDataset( - database="test", - table="test_save_upsert_mismatch", - write_mode="upsert", - primary_key="name", - ) - unity_ds.save(sample_spark_df) - with pytest.raises(DatasetError): - unity_ds.save(mismatched_upsert_spark_df) - - def test_load_spark(self, sample_spark_df: DataFrame): - unity_ds = ManagedTableDataset( - database="test", table="test_load_spark", write_mode="overwrite" - ) - unity_ds.save(sample_spark_df) - - delta_ds = ManagedTableDataset(database="test", table="test_load_spark") - delta_table = delta_ds.load() - - assert ( - isinstance(delta_table, DataFrame) - and delta_table.exceptAll(sample_spark_df).count() == 0 - ) - - def test_load_spark_no_version(self, sample_spark_df: DataFrame): - unity_ds = ManagedTableDataset( - database="test", table="test_load_spark", write_mode="overwrite" - ) - unity_ds.save(sample_spark_df) - - delta_ds = ManagedTableDataset( - database="test", table="test_load_spark", version=Version(2, None) - ) - with pytest.raises(VersionNotFoundError): - _ = delta_ds.load() - - def test_load_version(self, sample_spark_df: DataFrame, append_spark_df: DataFrame): - unity_ds = ManagedTableDataset( - database="test", table="test_load_version", write_mode="append" - ) - unity_ds.save(sample_spark_df) - unity_ds.save(append_spark_df) - - loaded_ds = ManagedTableDataset( - database="test", table="test_load_version", version=Version(0, None) - ) - loaded_df = loaded_ds.load() - - assert loaded_df.exceptAll(sample_spark_df).count() == 0 - - def test_load_pandas(self, sample_pandas_df: pd.DataFrame): - unity_ds = ManagedTableDataset( - database="test", - table="test_load_pandas", - dataframe_type="pandas", - write_mode="overwrite", - ) - unity_ds.save(sample_pandas_df) - - pandas_ds = ManagedTableDataset( - database="test", table="test_load_pandas", dataframe_type="pandas" - ) - pandas_df = pandas_ds.load().sort_values("name", ignore_index=True) - - assert isinstance(pandas_df, pd.DataFrame) and pandas_df.equals( - sample_pandas_df - )
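
For reviewers unfamiliar with the validation approach in ``BaseTable``, here is a minimal, self-contained sketch of the same pattern: a frozen dataclass whose ``__post_init__`` looks up ``_validate_<field>`` methods by naming convention and calls whichever exist. ``DemoTable`` and the local ``DatasetError`` are stand-ins invented for illustration, not part of this PR.

```python
# Minimal sketch of the validate-by-naming-convention pattern used by BaseTable.
# DemoTable and this local DatasetError are stand-ins for illustration only.
from __future__ import annotations

from dataclasses import dataclass
from typing import ClassVar


class DatasetError(Exception):
    """Stand-in for kedro.io.core.DatasetError."""


@dataclass(frozen=True)
class DemoTable:
    _VALID_WRITE_MODES: ClassVar[list[str]] = ["overwrite", "upsert", "append"]

    table: str
    write_mode: str | None = None

    def __post_init__(self) -> None:
        # For every declared field, call a _validate_<field> method if one exists.
        for name in self.__dataclass_fields__:
            method = getattr(self, f"_validate_{name}", None)
            if method:
                method()

    def _validate_write_mode(self) -> None:
        if (
            self.write_mode is not None
            and self.write_mode not in self._VALID_WRITE_MODES
        ):
            raise DatasetError(f"Invalid `write_mode`: {self.write_mode}")


DemoTable(table="ok", write_mode="append")  # passes validation silently

try:
    DemoTable(table="bad", write_mode="replace")  # rejected in __post_init__
except DatasetError as err:
    print(err)  # Invalid `write_mode`: replace
```

This is why adding a new validated attribute in a subclass only requires declaring the field and a matching ``_validate_<field>`` method.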
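``BaseTableDataset.__init__`` delegates table construction to ``_create_table``, which subclasses override to return their own table definition — exactly what ``ManagedTableDataset`` does above with ``ManagedTable`` and a delta-only ``_VALID_FORMATS``. The sketch below shows the shape a hypothetical subclass would take; ``MyExternalTable``/``MyExternalTableDataset`` and their format list are invented for illustration and are not the ``ExternalTableDataset`` referenced in this PR.

```python
# Illustrative sketch only: the shape a subclass of BaseTableDataset takes,
# mirroring ManagedTableDataset above. MyExternalTable/MyExternalTableDataset
# and the format list are invented; this is NOT the ExternalTableDataset
# referenced in this PR.
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, ClassVar

from kedro_datasets.databricks._base_table_dataset import BaseTable, BaseTableDataset


@dataclass(frozen=True)
class MyExternalTable(BaseTable):
    """Table definition accepting a few path-backed formats (illustrative)."""

    _VALID_FORMATS: ClassVar[list[str]] = field(default=["delta", "parquet", "csv"])


class MyExternalTableDataset(BaseTableDataset):
    """Dataset that returns the subclass-specific table definition."""

    def _create_table(  # noqa: PLR0913
        self,
        table: str,
        catalog: str | None,
        database: str,
        format: str,
        write_mode: str | None,
        location: str | None,
        dataframe_type: str,
        primary_key: str | list[str] | None,
        json_schema: dict[str, Any] | None,
        partition_columns: list[str] | None,
        owner_group: str | None,
    ) -> MyExternalTable:
        # Same shape as ManagedTableDataset._create_table: only the returned
        # dataclass differs, so its own validators and _VALID_FORMATS apply.
        return MyExternalTable(
            table=table,
            catalog=catalog,
            database=database,
            format=format,
            write_mode=write_mode,
            location=location,
            dataframe_type=dataframe_type,
            primary_key=primary_key,
            json_schema=json_schema,
            partition_columns=partition_columns,
            owner_group=owner_group,
        )
```

Because ``BaseTable.__post_init__`` re-runs every ``_validate_<field>`` check, narrowing ``_VALID_FORMATS`` on the dataclass is all a subclass needs in order to reject unsupported formats at construction time.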
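``_save`` matches incoming data to the declared JSON schema before dispatching to ``_save_<write_mode>``: pandas input is converted with the schema attached, Spark input is narrowed with ``select``. Below is a hedged sketch of just that column-matching step, assuming a running Spark session; the variable names are stand-ins.

```python
# Column-matching step from _save, extracted for illustration; needs a Spark session.
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

pd.DataFrame.iteritems = pd.DataFrame.items  # same pandas compatibility shim the module applies

spark = SparkSession.builder.getOrCreate()

declared_schema = StructType(
    [StructField("name", StringType(), True), StructField("age", IntegerType(), True)]
)
cols = declared_schema.fieldNames()  # ["name", "age"]

# Pandas input: extra columns (here "height") are dropped and the declared
# schema is applied while building the Spark dataframe.
pandas_df = pd.DataFrame({"name": ["Alex"], "age": [31], "height": [174]})
from_pandas = spark.createDataFrame(pandas_df.loc[:, cols], schema=declared_schema)

# Spark input: the same truncation happens with a select.
spark_df = spark.createDataFrame([("Alex", 31, 174)], ["name", "age", "height"])
narrowed = spark_df.select(*cols)

assert from_pandas.columns == narrowed.columns == cols
```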
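The upsert path builds its join predicate from ``primary_key`` and injects it, together with the table address, into a Spark SQL ``MERGE`` via ``spark.conf.set`` and ``${...}`` substitution (which relies on ``spark.sql.variable.substitute`` being enabled, its default). The hypothetical helper below reproduces only the string-building step so the resulting statement is easy to inspect:

```python
# Hypothetical helper reproducing the predicate/statement construction in _save_upsert.
from __future__ import annotations


def build_merge_sql(full_table_location: str, primary_key: str | list[str]) -> str:
    # A single key becomes one equality; a composite key becomes AND-ed equalities.
    if isinstance(primary_key, str):
        where_expr = f"base.{primary_key}=update.{primary_key}"
    else:
        where_expr = " AND ".join(f"base.{col}=update.{col}" for col in primary_key)
    # The dataset itself injects these two values through spark.conf.set(...)
    # and ${fullTableAddress}/${whereExpr} substitution rather than an f-string.
    return (
        f"MERGE INTO {full_table_location} base USING update ON {where_expr} "
        "WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *"
    )


print(build_merge_sql("`test`.`test_save_upsert_multiple`", ["name", "age"]))
# -> MERGE INTO `test`.`test_save_upsert_multiple` base USING update
#    ON base.name=update.name AND base.age=update.age WHEN MATCHED THEN ...
```

For a composite key such as ``["name", "age"]`` the predicate becomes ``base.name=update.name AND base.age=update.age``, matching the ``test_save_upsert_multiple_primary`` test above.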
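The public behaviour of ``ManagedTableDataset`` is unchanged by this refactor, so existing catalog entries and code keep working. The sketch below, adapted from the class docstring example and the tests with made-up table names, shows a save followed by a time-travel load pinned to version 0 via ``kedro.io.core.Version``, which ``_load`` translates into Delta's ``versionAsOf`` read option. It assumes a Databricks/Spark environment where the ``test`` schema exists and the table does not.

```python
# Usage sketch; assumes a Databricks/Spark session where schema `test` exists
# and table `test`.`people` does not (so the first save creates Delta version 0).
from kedro.io.core import Version
from pyspark.sql import SparkSession

from kedro_datasets.databricks import ManagedTableDataset

spark = SparkSession.builder.getOrCreate()
data = spark.createDataFrame([("Alex", 31), ("Bob", 12)], ["name", "age"])

dataset = ManagedTableDataset(database="test", table="people", write_mode="append")
dataset.save(data)

# Re-open the same table pinned to version 0; _load turns this into
# spark.read.format("delta").option("versionAsOf", 0).table(...)
first_version = ManagedTableDataset(
    database="test", table="people", version=Version(0, None)
)
reloaded = first_version.load()
assert reloaded.count() == 2
```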