Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat/2372 extend dataset querying API for incremental load #2386

Open
wants to merge 13 commits into
base: devel
Choose a base branch
from
18 changes: 18 additions & 0 deletions dlt/common/schema/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,24 @@ def get_root_table(tables: TSchemaTables, table_name: str) -> TTableSchema:
return table


def get_root_to_table_chain(tables: TSchemaTables, table_name: str) -> List[TTableSchema]:
    """Return the table schemas from the root table down to `table_name`.

    The chain is built by following each table's `parent` hint upwards, starting at
    `table_name`. It is returned root-first, so `table_name` is always the last element.

    Similar functions:
    - `get_root_table()` returns the root of the specified child instead of the full chain
    - `get_nested_tables()` returns the nested tables of the specified table, parents first

    Args:
        tables: Mapping of table name to table schema.
        table_name: Name of the table at which the chain ends.

    Returns:
        The table schemas ordered from the root table to `table_name`.

    Raises:
        KeyError: If `table_name` or a referenced parent is not present in `tables`.
        ValueError: If the `parent` references form a cycle.
    """
    # NOTE(review): a reviewer asked whether this helper is still needed. The author's reply
    # was that it may be removed for now, but could be wanted to support root-to-child joins
    # when `root_key=True` is missing.
    table = tables[table_name]
    chain: List[TTableSchema] = [table]
    # Names already on the chain; guards against looping forever on cyclic `parent` hints.
    visited = {table_name}

    while table.get("parent"):
        parent_name = table["parent"]
        if parent_name in visited:
            raise ValueError(
                f"Cyclic `parent` reference found at table `{parent_name}` while resolving"
                f" the chain of `{table_name}`."
            )
        visited.add(parent_name)
        table = tables[parent_name]
        chain.append(table)

    # The walk collected child -> root; callers expect root -> child.
    chain.reverse()
    return chain


def get_nested_tables(tables: TSchemaTables, table_name: str) -> List[TTableSchema]:
"""Get nested tables for table name and return a list of tables ordered by ancestry so the nested tables are always after their parents

Expand Down
128 changes: 121 additions & 7 deletions dlt/destinations/dataset/dataset.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
from typing import Any, Union, TYPE_CHECKING, List

from dlt.common.json import json

from dlt.common.exceptions import MissingDependencyException

from dlt.common.destination.reference import TDestinationReferenceArg, Destination
from dlt.common.destination.client import JobClientBase, WithStateSync
from dlt.common.destination.dataset import SupportsReadableRelation, SupportsReadableDataset
from dlt.common.destination.reference import TDestinationReferenceArg, Destination
from dlt.common.destination.typing import TDatasetType

from dlt.destinations.sql_client import SqlClientBase, WithSqlClient
from dlt.common.exceptions import MissingDependencyException
from dlt.common.json import json
from dlt.common.schema import Schema
from dlt.common.schema.utils import get_root_table
from dlt.destinations.dataset.relation import ReadableDBAPIRelation
from dlt.destinations.dataset.utils import get_destination_clients
from dlt.destinations.sql_client import SqlClientBase, WithSqlClient

if TYPE_CHECKING:
try:
Expand Down Expand Up @@ -121,6 +119,49 @@ def _ensure_client_and_schema(self) -> None:
def __call__(self, query: Any) -> ReadableDBAPIRelation:
    """Wrap `query` in a `ReadableDBAPIRelation` bound to this dataset."""
    relation = ReadableDBAPIRelation(readable_dataset=self, provided_query=query)  # type: ignore[abstract]
    return relation

def _filter_using_root_table(
    self,
    table_name: str,
    key: str,
    values_to_include: Union[list[Any], list[ReadableDBAPIRelation], ReadableDBAPIRelation],
) -> ReadableDBAPIRelation:
    """Filter the root table on `key` and propagate the filter to `table_name`.

    Case 1: `table_name` is a root table, then filter it directly and return.
    Case 2: `table_name` is not a root table, then keep the rows whose `root_key` column
        matches the `row_key` of a root row that passes the filter.

    NOTE. The Ibis backend supports a 3rd case, where parent-child relationships are traversed.
    Such recursive query is difficult to express in raw SQL.

    Args:
        table_name: Name of the table to return rows from.
        key: Column of the root table that is compared against `values_to_include`.
        values_to_include: Values (or relation(s)) the `key` column must match.

    Returns:
        A relation selecting the matching rows of `table_name`.

    Raises:
        ValueError: If `table_name` is not a root table and has no `root_key` column, or if
            its root table has no `row_key` column.
    """

    def _sql_literal(value: Any) -> str:
        # Render a single Python value as a SQL literal; strings get their quotes escaped.
        if value is None:
            return "NULL"
        if isinstance(value, bool):
            return "TRUE" if value else "FALSE"
        if isinstance(value, (int, float)):
            return str(value)
        return "'" + str(value).replace("'", "''") + "'"

    def _sql_operand(value: Any) -> str:
        # Relations are interpolated via `str()` exactly as before, only parenthesized.
        # TODO(review): confirm that `str(relation)` renders the relation's SQL.
        if isinstance(value, ReadableDBAPIRelation):
            return f"({value})"
        return _sql_literal(value)

    if isinstance(values_to_include, (list, tuple)):
        rendered = ", ".join(_sql_operand(value) for value in values_to_include)
        # `IN ()` is not valid SQL; `IN (NULL)` is valid and matches no rows.
        in_operand = f"({rendered})" if rendered else "(NULL)"
    else:
        in_operand = _sql_operand(values_to_include)

    # TODO use `sql_client` utils to use fully qualified names that respect destination
    value_filter = f"{key} IN {in_operand}"

    # Case 1: if current table is root table, you can filter directly and exit early
    dlt_root_table = get_root_table(self.schema.tables, table_name)
    root_table_name = dlt_root_table["name"]
    if root_table_name == table_name:
        return self(f"SELECT * FROM {table_name} WHERE {value_filter}")

    # Case 2: There is a root_key to join root with current table
    # (e.g., root_key=True on Source, write_disposition="merge")
    table_columns = self.schema.tables[table_name]["columns"]
    root_key = next(
        (col_name for col_name, col in table_columns.items() if col.get("root_key") is True),
        None,
    )
    if root_key is None:
        # TODO improve error handling
        raise ValueError(f"Table `{table_name}` has no `root_key` to join with the load table.")

    root_row_key = next(
        (
            col_name
            for col_name, col in dlt_root_table["columns"].items()
            if col.get("row_key") is True
        ),
        None,
    )
    if root_row_key is None:
        raise ValueError(f"Root table `{root_table_name}` has no `row_key` column to join on.")

    # The subquery selects the row keys of the root rows that pass the filter.
    query = (
        f"SELECT * FROM {table_name} "
        f"WHERE {root_key} IN ("
        f"SELECT {root_row_key} FROM {root_table_name} WHERE {value_filter})"
    )
    return self(query)

def table(self, table_name: str) -> SupportsReadableRelation:
# we can create an ibis powered relation if ibis is available
if table_name in self.schema.tables and self._dataset_type in ("auto", "ibis"):
Expand All @@ -133,6 +174,7 @@ def table(self, table_name: str) -> SupportsReadableRelation:
readable_dataset=self,
ibis_object=unbound_table,
columns_schema=self.schema.tables[table_name]["columns"],
table_name=table_name,
)
except MissingDependencyException:
# if ibis is explicitly requested, reraise
Expand Down Expand Up @@ -171,6 +213,78 @@ def row_counts(
# Execute query and build result dict
return self(query)

def list_load_ids(
    self, status: Union[int, list[int], None] = 0, limit: Union[int, None] = 10
) -> SupportsReadableRelation:
    """Return a relation over the most recent `load_id`s in descending order.

    If no `load_id` matches, the query selects no rows.

    Args:
        status: Load status, or list of statuses, to keep. `None` disables the status filter.
        limit: Maximum number of `load_id`s to return. `None` disables the limit.

    Returns:
        A relation selecting the `load_id` column of the loads table.
    """
    # NOTE(review): a reviewer asked for this to be implemented as an ibis expression that
    # raises if ibis is not installed.
    # TODO use `sql_client` utils to use fully qualified names that respect destination
    query = f"SELECT load_id FROM {self.schema.loads_table_name}"

    if status is not None:
        status_list = [status] if isinstance(status, int) else status
        # Coercing to `int` keeps arbitrary text out of the query.
        rendered = ", ".join(str(int(value)) for value in status_list)
        # `IN ()` is not valid SQL; `IN (NULL)` is valid and matches no rows.
        query += f" WHERE status IN ({rendered or 'NULL'})"

    query += " ORDER BY load_id DESC"

    if limit is not None:
        query += f" LIMIT {int(limit)}"

    return self(query)

def latest_load_id(self, status: Union[int, list[int], None] = 0) -> SupportsReadableRelation:
    """Return a relation selecting the latest (maximum) `load_id`.

    Args:
        status: Load status, or list of statuses, to consider. `None` disables the
            status filter.

    Returns:
        A relation selecting `max(load_id)` from the loads table.
    """
    # TODO use `sql_client` utils to use fully qualified names that respect destination
    query = f"SELECT max(load_id) FROM {self.schema.loads_table_name}"
    if status is not None:
        status_list = [status] if isinstance(status, int) else status
        # Coercing to `int` keeps arbitrary text out of the query.
        rendered = ", ".join(str(int(value)) for value in status_list)
        # `IN ()` is not valid SQL; `IN (NULL)` is valid and matches no rows.
        query += f" WHERE status IN ({rendered or 'NULL'})"

    return self(query)

def filter_by_load_ids(self, table_name: str, load_ids: Union[str, list[str]]) -> ReadableDBAPIRelation:
    """Return the rows of `table_name` whose load matches one of `load_ids`.

    A single `load_id` string is accepted and treated as a one-element list.
    """
    if isinstance(load_ids, str):
        load_ids = [load_ids]

    filtered = self._filter_using_root_table(
        table_name=table_name,
        key="_dlt_load_id",
        values_to_include=load_ids,
    )
    return filtered

def filter_by_load_status(
    self, table_name: str, status: Union[int, list[int], None] = 0
) -> ReadableDBAPIRelation:
    """Return the rows of `table_name` that belong to loads with the given status.

    Args:
        table_name: Name of the table to return rows from.
        status: Load status, or list of statuses, to keep. If `None`, no filter is
            applied and all rows of `table_name` are selected.

    Returns:
        A relation selecting the matching rows of `table_name`.
    """
    if status is None:
        # Nothing to filter on: return a relation over the whole table, not the dataset
        # itself, so callers always receive the declared return type.
        # TODO use `sql_client` utils to use fully qualified names that respect destination
        return self(f"SELECT * FROM {table_name}")

    # `limit=None` lifts the default cap of `list_load_ids()`, which would otherwise
    # silently drop every load older than the most recent ones.
    load_ids = [row[0] for row in self.list_load_ids(status=status, limit=None).fetchall()]
    return self._filter_using_root_table(
        table_name=table_name, key="_dlt_load_id", values_to_include=load_ids
    )

def filter_by_latest_load_id(
    self, table_name: str, status: Union[int, list[int], None] = 0
) -> ReadableDBAPIRelation:
    """Filter on the most recent `load_id` with a specific load status.

    If `status` is None, don't filter by status.

    Args:
        table_name: Name of the table to return rows from.
        status: Load status, or list of statuses, considered when picking the latest load.

    Returns:
        A relation selecting the rows of `table_name` that belong to the latest load.
        If no load matches, an empty list of values is used as the filter.
    """
    rows = self.latest_load_id(status=status).fetchall()
    # `max()` over zero matching rows yields NULL, and a backend may return no row at all;
    # in both cases there is no load to filter on.
    latest_load_id = rows[0][0] if rows else None
    load_ids = [] if latest_load_id is None else [latest_load_id]
    return self._filter_using_root_table(
        table_name=table_name, key="_dlt_load_id", values_to_include=load_ids
    )

def filter_by_load_id_gt(
    self, table_name: str, load_id: str, status: Union[int, list[int], None] = 0
) -> ReadableDBAPIRelation:
    """Return the rows of `table_name` from loads whose `load_id` is greater than `load_id`.

    All `load_id`s with the given `status` are fetched without a limit and compared
    in Python; the surviving ids are then used to filter `table_name`.
    """
    fetched = self.list_load_ids(status=status, limit=None).fetchall()

    newer_load_ids = []
    for record in fetched:
        if record[0] > load_id:
            newer_load_ids.append(record[0])

    return self._filter_using_root_table(
        table_name=table_name,
        key="_dlt_load_id",
        values_to_include=newer_load_ids,
    )

def __getitem__(self, table_name: str) -> SupportsReadableRelation:
    """Dict-style table access: `dataset[name]` delegates to `self.table(name)`."""
    relation = self.table(table_name)
    return relation
Expand Down
Loading
Loading