Skip to content

Commit

Permalink
fixed lint issues
Browse files Browse the repository at this point in the history
  • Loading branch information
MinuraPunchihewa committed Oct 4, 2024
1 parent 58a8691 commit e4d8ea1
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 34 deletions.
49 changes: 30 additions & 19 deletions kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,16 @@
import logging
import re
from dataclasses import dataclass, field
from typing import Any, ClassVar, List
from typing import Any, ClassVar

import pandas as pd
from kedro.io.core import (
AbstractVersionedDataset,
DatasetError,
Version,
VersionNotFoundError
VersionNotFoundError,
)
from pyspark.sql import DataFrame
from pyspark.sql.readwriter import DataFrameWriter
from pyspark.sql.types import StructType
from pyspark.sql.utils import AnalysisException, ParseException

Expand All @@ -29,14 +28,19 @@
@dataclass(frozen=True)
class BaseTable:
"""Stores the definition of a base table.
Acts as the base class for `ManagedTable` and `ExternalTable`.
"""

# regex for tables, catalogs and schemas
_NAMING_REGEX: ClassVar[str] = r"\b[0-9a-zA-Z_-]{1,}\b"
_VALID_WRITE_MODES: ClassVar[List[str]] = field(default=["overwrite", "upsert", "append"])
_VALID_DATAFRAME_TYPES: ClassVar[List[str]] = field(default=["spark", "pandas"])
_VALID_FORMATS: ClassVar[List[str]] = field(default=["delta", "parquet", "csv", "json", "orc", "avro", "text"])
_VALID_WRITE_MODES: ClassVar[list[str]] = field(
default=["overwrite", "upsert", "append"]
)
_VALID_DATAFRAME_TYPES: ClassVar[list[str]] = field(default=["spark", "pandas"])
_VALID_FORMATS: ClassVar[list[str]] = field(
default=["delta", "parquet", "csv", "json", "orc", "avro", "text"]
)

database: str
catalog: str | None
Expand All @@ -47,7 +51,7 @@ class BaseTable:
primary_key: str | list[str] | None
owner_group: str | None
partition_columns: str | list[str] | None
format: str = "delta",
# Default table format. This must be a plain string, not a tuple: a trailing
# comma after the literal (or its parenthesised form) makes the default a
# one-element tuple, which contradicts the `str` annotation and the string
# entries of `_VALID_FORMATS`, and is later handed to `data.write.format(...)`.
format: str = "delta"
json_schema: dict[str, Any] | None = None

def __post_init__(self):
Expand Down Expand Up @@ -229,7 +233,7 @@ def __init__( # noqa: PLR0913
schema: dict[str, Any] | None = None,
partition_columns: list[str] | None = None,
owner_group: str | None = None,
metadata: dict[str, Any] | None = None
metadata: dict[str, Any] | None = None,
) -> None:
"""Creates a new instance of ``BaseTableDataset``.
Expand Down Expand Up @@ -306,7 +310,7 @@ def _create_table( # noqa: PLR0913
primary_key: str | list[str] | None,
json_schema: dict[str, Any] | None,
partition_columns: list[str] | None,
owner_group: str | None
owner_group: str | None,
) -> BaseTable:
"""Creates a ``BaseTable`` instance with the provided attributes.
Expand All @@ -315,6 +319,7 @@ def _create_table( # noqa: PLR0913
catalog: The catalog of the table.
database: The database of the table.
format: The format of the table.
location: The location of the table.
write_mode: The write mode for the table.
dataframe_type: The type of dataframe.
primary_key: The primary key of the table.
Expand All @@ -338,7 +343,7 @@ def _create_table( # noqa: PLR0913
owner_group=owner_group,
primary_key=primary_key,
)

def _load(self) -> DataFrame | pd.DataFrame:
"""Loads the version of data in the format defined in the init
(spark|pandas dataframe).
Expand Down Expand Up @@ -366,7 +371,7 @@ def _load(self) -> DataFrame | pd.DataFrame:
if self._table.dataframe_type == "pandas":
data = data.toPandas()
return data

def _save(self, data: DataFrame | pd.DataFrame) -> None:
"""Saves the data based on the write_mode and dataframe_type in the init.
If dataframe_type is pandas, a Spark dataframe is created first.
Expand Down Expand Up @@ -401,9 +406,9 @@ def _save(self, data: DataFrame | pd.DataFrame) -> None:
f"Invalid `write_mode` provided: {self._table.write_mode}. "
f"`write_mode` must be one of: {self._table._VALID_WRITE_MODES}"
)

method(data)

def _save_append(self, data: DataFrame) -> None:
"""Saves the data to the table by appending it
to the location defined in the init.
Expand All @@ -415,7 +420,9 @@ def _save_append(self, data: DataFrame) -> None:

if self._table.partition_columns:
writer.partitionBy(
*self._table.partition_columns if isinstance(self._table.partition_columns, list) else [self._table.partition_columns]
*self._table.partition_columns
if isinstance(self._table.partition_columns, list)
else [self._table.partition_columns]
)

if self._table.location:
Expand All @@ -429,13 +436,17 @@ def _save_overwrite(self, data: DataFrame) -> None:
Args:
data (DataFrame): The Spark dataframe to overwrite the table with.
"""
writer = data.write.format(self._table.format).mode("overwrite").option(
"overwriteSchema", "true"
writer = (
data.write.format(self._table.format)
.mode("overwrite")
.option("overwriteSchema", "true")
)

if self._table.partition_columns:
writer.partitionBy(
*self._table.partition_columns if isinstance(self._table.partition_columns, list) else [self._table.partition_columns]
*self._table.partition_columns
if isinstance(self._table.partition_columns, list)
else [self._table.partition_columns]
)

if self._table.location:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@

import logging
from dataclasses import dataclass, field
from typing import Any, ClassVar, List
from typing import Any, ClassVar

import pandas as pd
from kedro.io.core import (
Version
)
from kedro.io.core import Version

from kedro_datasets.databricks._base_table_dataset import BaseTable, BaseTableDataset

Expand All @@ -22,7 +20,7 @@
class ManagedTable(BaseTable):
"""Stores the definition of a managed table."""

_VALID_FORMATS: ClassVar[List[str]] = field(default=["delta"])
_VALID_FORMATS: ClassVar[list[str]] = field(default=["delta"])


class ManagedTableDataset(BaseTableDataset):
Expand Down Expand Up @@ -82,6 +80,7 @@ class ManagedTableDataset(BaseTableDataset):
>>> reloaded = dataset.load()
>>> assert Row(name="Bob", age=12) in reloaded.take(4)
"""

def __init__( # noqa: PLR0913
self,
*,
Expand Down Expand Up @@ -158,7 +157,7 @@ def _create_table( # noqa: PLR0913
primary_key: str | list[str] | None,
json_schema: dict[str, Any] | None,
partition_columns: list[str] | None,
owner_group: str | None
owner_group: str | None,
) -> ManagedTable:
"""Creates a new ``ManagedTable`` instance with the provided attributes.
Expand Down Expand Up @@ -188,9 +187,9 @@ def _create_table( # noqa: PLR0913
partition_columns=partition_columns,
owner_group=owner_group,
primary_key=primary_key,
format=format
format=format,
)

def _describe(self) -> dict[str, str | list | None]:
"""Returns a description of the instance of the dataset.
Expand All @@ -202,4 +201,3 @@ def _describe(self) -> dict[str, str | list | None]:
del description["location"]

return description

4 changes: 1 addition & 3 deletions kedro-datasets/tests/databricks/test_base_table_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,7 @@ def test_full_table(self):
unity_ds = BaseTableDataset(catalog="test", database="test", table="test")
assert unity_ds._table.full_table_location() == "`test`.`test`.`test`"

unity_ds = BaseTableDataset(
catalog="test-test", database="test", table="test"
)
unity_ds = BaseTableDataset(catalog="test-test", database="test", table="test")
assert unity_ds._table.full_table_location() == "`test-test`.`test`.`test`"

unity_ds = BaseTableDataset(database="test", table="test")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pandas as pd
import pytest
from kedro.io.core import DatasetError, Version, VersionNotFoundError
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

from kedro_datasets.databricks import ManagedTableDataset
Expand Down Expand Up @@ -170,7 +169,6 @@ def expected_upsert_multiple_primary_spark_df(spark_session: SparkSession):


class TestManagedTableDataset:

def test_describe(self):
unity_ds = ManagedTableDataset(table="test")
assert unity_ds._describe() == {
Expand Down

0 comments on commit e4d8ea1

Please sign in to comment.