synchronise pipeline protocol definitions and fix signature to use proper typing #2364

Open · wants to merge 4 commits into devel

Changes from 3 commits
51 changes: 20 additions & 31 deletions dlt/common/pipeline.py
@@ -15,15 +15,14 @@
     NamedTuple,
     Optional,
     Protocol,
-    Sequence,
     Tuple,
     TypeVar,
     Mapping,
     Literal,
 )
 from typing_extensions import NotRequired

-from dlt.common.typing import TypedDict
+from dlt.common.typing import TypedDict, Unpack
 from dlt.common.configuration import configspec
 from dlt.common.configuration import known_sections
 from dlt.common.configuration.container import Container
@@ -46,7 +45,7 @@
     NormalizeMetrics,
     StepMetrics,
 )
-from dlt.common.schema import Schema
+from dlt.common.schema import Schema, TAnySchemaColumns, TTableFormat
 from dlt.common.schema.typing import (
     TColumnSchema,
     TWriteDispositionConfig,
@@ -474,6 +473,22 @@ class TSourceState(TPipelineState):
     sources: Dict[str, Dict[str, Any]]  # type: ignore[misc]


+class SupportsPipelineRunArgs(TypedDict, total=False):
+    destination: Optional[TDestinationReferenceArg]
+    staging: Optional[TDestinationReferenceArg]
+    dataset_name: Optional[str]
+    credentials: Optional[Any]
+    table_name: Optional[str]
+    write_disposition: Optional[TWriteDispositionConfig]
+    columns: Optional[TAnySchemaColumns]
+    primary_key: Optional[TColumnNames]
+    schema: Optional[Schema]
+    loader_file_format: Optional[TLoaderFileFormat]
+    table_format: Optional[TTableFormat]
+    schema_contract: Optional[TSchemaContract]
+    refresh: Optional[TRefreshMode]
+
+
 class SupportsPipeline(Protocol):
     """A protocol with core pipeline operations that lets high-level abstractions, i.e. sources, access pipeline methods and properties"""
@@ -508,21 +523,7 @@ def set_local_state_val(self, key: str, value: Any) -> None:
     def get_local_state_val(self, key: str) -> Any:
         """Gets value from local state. Local state is not synchronized with destination."""

-    def run(
-        self,
-        data: Any = None,
-        *,
-        destination: TDestinationReferenceArg = None,
-        dataset_name: str = None,
-        credentials: Any = None,
-        table_name: str = None,
-        write_disposition: TWriteDispositionConfig = None,
-        columns: Sequence[TColumnSchema] = None,
-        primary_key: TColumnNames = None,
-        schema: Schema = None,
-        loader_file_format: TLoaderFileFormat = None,
-        schema_contract: TSchemaContract = None,
-    ) -> LoadInfo: ...
+    def run(self, data: Any = None, **kwargs: Unpack[SupportsPipelineRunArgs]) -> LoadInfo: ...

     def _set_context(self, is_active: bool) -> None:
         """Called when pipeline context is activated or deactivated"""
@@ -532,19 +533,7 @@ def _make_schema_with_default_name(self) -> Schema:


 class SupportsPipelineRun(Protocol):
-    def __call__(
-        self,
-        *,
-        destination: TDestinationReferenceArg = None,
-        dataset_name: str = None,
-        credentials: Any = None,
-        table_name: str = None,
-        write_disposition: TWriteDispositionConfig = None,
-        columns: Sequence[TColumnSchema] = None,
-        schema: Schema = None,
-        loader_file_format: TLoaderFileFormat = None,
-        schema_contract: TSchemaContract = None,
-    ) -> LoadInfo: ...
+    def __call__(self, data: Any = None, **kwargs: Unpack[SupportsPipelineRunArgs]) -> LoadInfo: ...


 @configspec
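Reviewer note: the hunks above are the core of the change. PEP 692's Unpack over a total=False TypedDict lets a type checker validate **kwargs by key name and value type, so the protocol methods and the concrete Pipeline.run can share a single argument definition instead of several hand-maintained signatures that drift apart. A minimal, self-contained sketch of the pattern; the names RunArgs, SupportsRun and Runner are hypothetical and not part of this diff:

from typing import Any, Optional, Protocol
from typing_extensions import TypedDict, Unpack


class RunArgs(TypedDict, total=False):
    # total=False: every key is optional at the call site
    dataset_name: Optional[str]
    table_name: Optional[str]


class SupportsRun(Protocol):
    def run(self, data: Any = None, **kwargs: Unpack[RunArgs]) -> None: ...


class Runner:
    def run(self, data: Any = None, **kwargs: Unpack[RunArgs]) -> None:
        # keys the caller omitted are simply absent from kwargs
        print(kwargs.get("dataset_name"), kwargs.get("table_name"))


runner: SupportsRun = Runner()  # structural (Protocol) check, no inheritance needed
runner.run([1, 2, 3], dataset_name="analytics")  # OK
# runner.run([1, 2, 3], datasetname="typo")      # rejected by mypy/pyright

With SupportsPipelineRunArgs shared by SupportsPipeline.run, SupportsPipelineRun.__call__ and Pipeline.run, a future argument only needs to be added in one place.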
4 changes: 4 additions & 0 deletions dlt/common/schema/__init__.py
@@ -8,6 +8,8 @@
     TColumnHint,
     TColumnSchema,
     TColumnSchemaBase,
+    TAnySchemaColumns,
+    TTableFormat,
 )
 from dlt.common.schema.typing import COLUMN_HINTS
 from dlt.common.schema.schema import Schema, DEFAULT_SCHEMA_CONTRACT_MODE
@@ -23,6 +25,8 @@
     "TColumnHint",
     "TColumnSchema",
     "TColumnSchemaBase",
+    "TAnySchemaColumns",
+    "TTableFormat",
     "COLUMN_HINTS",
     "Schema",
     "verify_schema_hash",
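This file only widens the package surface: TAnySchemaColumns and TTableFormat are re-exported so dlt/common/pipeline.py can import them next to Schema rather than reaching into dlt.common.schema.typing. A quick check of what the re-export enables (assuming this branch is installed):

# both names now resolve from the package root, not just from dlt.common.schema.typing
from dlt.common.schema import Schema, TAnySchemaColumns, TTableFormat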
38 changes: 16 additions & 22 deletions dlt/pipeline/pipeline.py
@@ -17,6 +17,7 @@
     ContextManager,
     Union,
 )
+from typing_extensions import Unpack

 import dlt
 from dlt.common import logger
@@ -84,6 +85,7 @@
     LoadInfo,
     NormalizeInfo,
     PipelineContext,
+    SupportsPipelineRunArgs,
     TStepInfo,
     SupportsPipeline,
     TPipelineLocalState,
@@ -618,20 +620,7 @@ def load(
     def run(
         self,
         data: Any = None,
-        *,
-        destination: TDestinationReferenceArg = None,
-        staging: TDestinationReferenceArg = None,
-        dataset_name: str = None,
-        credentials: Any = None,
-        table_name: str = None,
-        write_disposition: TWriteDispositionConfig = None,
-        columns: TAnySchemaColumns = None,
-        primary_key: TColumnNames = None,
-        schema: Schema = None,
-        loader_file_format: TLoaderFileFormat = None,
-        table_format: TTableFormat = None,
-        schema_contract: TSchemaContract = None,
-        refresh: Optional[TRefreshMode] = None,
+        **kwargs: Unpack[SupportsPipelineRunArgs],
     ) -> LoadInfo:
         """Loads the data from `data` argument into the destination specified in `destination` and dataset specified in `dataset_name`.

@@ -696,6 +685,11 @@ def run(
         Returns:
             LoadInfo: Information on loaded data including the list of package ids and failed job statuses. Please note that `dlt` will not raise if a single job terminally fails. Such information is provided via LoadInfo.
         """
+        destination = kwargs.get("destination", None)
+        credentials = kwargs.get("credentials", None)
+        staging = kwargs.get("staging", None)
+        dataset_name = kwargs.get("dataset_name", None)
+        loader_file_format = kwargs.get("loader_file_format", None)

         signals.raise_if_signalled()
         self.activate()
@@ -731,14 +725,14 @@
         if data is not None:
             self.extract(
                 data,
-                table_name=table_name,
-                write_disposition=write_disposition,
-                columns=columns,
-                primary_key=primary_key,
-                schema=schema,
-                table_format=table_format,
-                schema_contract=schema_contract,
-                refresh=refresh or self.refresh,
+                table_name=kwargs.get("table_name", None),
+                write_disposition=kwargs.get("write_disposition", None),
+                columns=kwargs.get("columns", None),
+                primary_key=kwargs.get("primary_key", None),
+                schema=kwargs.get("schema", None),
+                table_format=kwargs.get("table_format", None),
+                schema_contract=kwargs.get("schema_contract", None),
+                refresh=kwargs.get("refresh", self.refresh),
             )
             self.normalize(loader_file_format=loader_file_format)
             return self.load(destination, dataset_name, credentials=credentials)
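One subtlety in the rewritten body is worth flagging: kwargs.get(key, default) falls back to the default only when the key is absent, not when the caller explicitly passes None. The old refresh or self.refresh expression treated an explicit refresh=None the same as an omitted argument; kwargs.get("refresh", self.refresh) does not. A small sketch of the difference, with illustrative names only (Args and PIPELINE_DEFAULT are not from this diff):

from typing import Optional
from typing_extensions import TypedDict, Unpack


class Args(TypedDict, total=False):
    refresh: Optional[str]


PIPELINE_DEFAULT = "drop_sources"


def old_style(refresh: Optional[str] = None) -> str:
    return refresh or PIPELINE_DEFAULT  # explicit None falls through to the default


def new_style(**kwargs: Unpack[Args]) -> Optional[str]:
    return kwargs.get("refresh", PIPELINE_DEFAULT)  # only a *missing* key falls through


print(old_style(refresh=None))  # drop_sources
print(new_style(refresh=None))  # None -- explicit None is kept
print(new_style())              # drop_sources

Whether that behavioral change for an explicit None is intended is a question for the author; if not, restoring the old semantics with `kwargs.get("refresh") or self.refresh` would be a one-line fix.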