From 5a05a3468962a2147ae5b754dc23f3739d2b3cda Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Fri, 4 Oct 2024 19:57:58 +0530 Subject: [PATCH] fixed lint issues Signed-off-by: Minura Punchihewa --- .../databricks/_base_table_dataset.py | 49 ++++++++++++------- .../databricks/managed_table_dataset.py | 16 +++--- .../databricks/test_base_table_dataset.py | 4 +- .../databricks/test_managed_table_dataset.py | 4 +- 4 files changed, 39 insertions(+), 34 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py index cba1798a3..de893f633 100644 --- a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py @@ -6,17 +6,16 @@ import logging import re from dataclasses import dataclass, field -from typing import Any, ClassVar, List +from typing import Any, ClassVar import pandas as pd from kedro.io.core import ( AbstractVersionedDataset, DatasetError, Version, - VersionNotFoundError + VersionNotFoundError, ) from pyspark.sql import DataFrame -from pyspark.sql.readwriter import DataFrameWriter from pyspark.sql.types import StructType from pyspark.sql.utils import AnalysisException, ParseException @@ -29,14 +28,19 @@ @dataclass(frozen=True) class BaseTable: """Stores the definition of a base table. - + Acts as the base class for `ManagedTable` and `ExternalTable`. """ + # regex for tables, catalogs and schemas _NAMING_REGEX: ClassVar[str] = r"\b[0-9a-zA-Z_-]{1,}\b" - _VALID_WRITE_MODES: ClassVar[List[str]] = field(default=["overwrite", "upsert", "append"]) - _VALID_DATAFRAME_TYPES: ClassVar[List[str]] = field(default=["spark", "pandas"]) - _VALID_FORMATS: ClassVar[List[str]] = field(default=["delta", "parquet", "csv", "json", "orc", "avro", "text"]) + _VALID_WRITE_MODES: ClassVar[list[str]] = field( + default=["overwrite", "upsert", "append"] + ) + _VALID_DATAFRAME_TYPES: ClassVar[list[str]] = field(default=["spark", "pandas"]) + _VALID_FORMATS: ClassVar[list[str]] = field( + default=["delta", "parquet", "csv", "json", "orc", "avro", "text"] + ) database: str catalog: str | None @@ -47,7 +51,7 @@ class BaseTable: primary_key: str | list[str] | None owner_group: str | None partition_columns: str | list[str] | None - format: str = "delta", + format: str = ("delta",) json_schema: dict[str, Any] | None = None def __post_init__(self): @@ -229,7 +233,7 @@ def __init__( # noqa: PLR0913 schema: dict[str, Any] | None = None, partition_columns: list[str] | None = None, owner_group: str | None = None, - metadata: dict[str, Any] | None = None + metadata: dict[str, Any] | None = None, ) -> None: """Creates a new instance of ``BaseTableDataset``. @@ -306,7 +310,7 @@ def _create_table( # noqa: PLR0913 primary_key: str | list[str] | None, json_schema: dict[str, Any] | None, partition_columns: list[str] | None, - owner_group: str | None + owner_group: str | None, ) -> BaseTable: """Creates a ``BaseTable`` instance with the provided attributes. @@ -315,6 +319,7 @@ def _create_table( # noqa: PLR0913 catalog: The catalog of the table. database: The database of the table. format: The format of the table. + location: The location of the table. write_mode: The write mode for the table. dataframe_type: The type of dataframe. primary_key: The primary key of the table. @@ -338,7 +343,7 @@ def _create_table( # noqa: PLR0913 owner_group=owner_group, primary_key=primary_key, ) - + def _load(self) -> DataFrame | pd.DataFrame: """Loads the version of data in the format defined in the init (spark|pandas dataframe). @@ -366,7 +371,7 @@ def _load(self) -> DataFrame | pd.DataFrame: if self._table.dataframe_type == "pandas": data = data.toPandas() return data - + def _save(self, data: DataFrame | pd.DataFrame) -> None: """Saves the data based on the write_mode and dataframe_type in the init. If write_mode is pandas, Spark dataframe is created first. @@ -401,9 +406,9 @@ def _save(self, data: DataFrame | pd.DataFrame) -> None: f"Invalid `write_mode` provided: {self._table.write_mode}. " f"`write_mode` must be one of: {self._table._VALID_WRITE_MODES}" ) - + method(data) - + def _save_append(self, data: DataFrame) -> None: """Saves the data to the table by appending it to the location defined in the init. @@ -415,7 +420,9 @@ def _save_append(self, data: DataFrame) -> None: if self._table.partition_columns: writer.partitionBy( - *self._table.partition_columns if isinstance(self._table.partition_columns, list) else [self._table.partition_columns] + *self._table.partition_columns + if isinstance(self._table.partition_columns, list) + else [self._table.partition_columns] ) if self._table.location: @@ -429,13 +436,17 @@ def _save_overwrite(self, data: DataFrame) -> None: Args: data (DataFrame): The Spark dataframe to overwrite the table with. """ - writer = data.write.format(self._table.format).mode("overwrite").option( - "overwriteSchema", "true" + writer = ( + data.write.format(self._table.format) + .mode("overwrite") + .option("overwriteSchema", "true") ) - + if self._table.partition_columns: writer.partitionBy( - *self._table.partition_columns if isinstance(self._table.partition_columns, list) else [self._table.partition_columns] + *self._table.partition_columns + if isinstance(self._table.partition_columns, list) + else [self._table.partition_columns] ) if self._table.location: diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index fd4a3964e..6185a498b 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -5,12 +5,10 @@ import logging from dataclasses import dataclass, field -from typing import Any, ClassVar, List +from typing import Any, ClassVar import pandas as pd -from kedro.io.core import ( - Version -) +from kedro.io.core import Version from kedro_datasets.databricks._base_table_dataset import BaseTable, BaseTableDataset @@ -22,7 +20,7 @@ class ManagedTable(BaseTable): """Stores the definition of a managed table.""" - _VALID_FORMATS: ClassVar[List[str]] = field(default=["delta"]) + _VALID_FORMATS: ClassVar[list[str]] = field(default=["delta"]) class ManagedTableDataset(BaseTableDataset): @@ -82,6 +80,7 @@ class ManagedTableDataset(BaseTableDataset): >>> reloaded = dataset.load() >>> assert Row(name="Bob", age=12) in reloaded.take(4) """ + def __init__( # noqa: PLR0913 self, *, @@ -158,7 +157,7 @@ def _create_table( # noqa: PLR0913 primary_key: str | list[str] | None, json_schema: dict[str, Any] | None, partition_columns: list[str] | None, - owner_group: str | None + owner_group: str | None, ) -> ManagedTable: """Creates a new ``ManagedTable`` instance with the provided attributes. @@ -188,9 +187,9 @@ def _create_table( # noqa: PLR0913 partition_columns=partition_columns, owner_group=owner_group, primary_key=primary_key, - format=format + format=format, ) - + def _describe(self) -> dict[str, str | list | None]: """Returns a description of the instance of the dataset. @@ -202,4 +201,3 @@ def _describe(self) -> dict[str, str | list | None]: del description["location"] return description - diff --git a/kedro-datasets/tests/databricks/test_base_table_dataset.py b/kedro-datasets/tests/databricks/test_base_table_dataset.py index 6f0182474..dddfe794a 100644 --- a/kedro-datasets/tests/databricks/test_base_table_dataset.py +++ b/kedro-datasets/tests/databricks/test_base_table_dataset.py @@ -174,9 +174,7 @@ def test_full_table(self): unity_ds = BaseTableDataset(catalog="test", database="test", table="test") assert unity_ds._table.full_table_location() == "`test`.`test`.`test`" - unity_ds = BaseTableDataset( - catalog="test-test", database="test", table="test" - ) + unity_ds = BaseTableDataset(catalog="test-test", database="test", table="test") assert unity_ds._table.full_table_location() == "`test-test`.`test`.`test`" unity_ds = BaseTableDataset(database="test", table="test") diff --git a/kedro-datasets/tests/databricks/test_managed_table_dataset.py b/kedro-datasets/tests/databricks/test_managed_table_dataset.py index aac95cd7a..c3cc623f4 100644 --- a/kedro-datasets/tests/databricks/test_managed_table_dataset.py +++ b/kedro-datasets/tests/databricks/test_managed_table_dataset.py @@ -1,7 +1,6 @@ import pandas as pd import pytest -from kedro.io.core import DatasetError, Version, VersionNotFoundError -from pyspark.sql import DataFrame, SparkSession +from pyspark.sql import SparkSession from pyspark.sql.types import IntegerType, StringType, StructField, StructType from kedro_datasets.databricks import ManagedTableDataset @@ -170,7 +169,6 @@ def expected_upsert_multiple_primary_spark_df(spark_session: SparkSession): class TestManagedTableDataset: - def test_describe(self): unity_ds = ManagedTableDataset(table="test") assert unity_ds._describe() == {