add support for emitting sql jobs from a resource #2367

Draft · wants to merge 3 commits into base: devel
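In short: a resource hinted with file_format="sql" yields raw SQL statements as text; dlt writes them into .sql load jobs and the destination executes them. A minimal usage sketch, adapted from the test added in this PR (duckdb only for now; names are illustrative):

import dlt

pipeline = dlt.pipeline(
    pipeline_name="example_pipeline", destination="duckdb", dataset_name="example_dataset"
)
# seed a table that the SQL statement below can read from
pipeline.run([{"a": i} for i in range(10)], table_name="example_table")
dataset = pipeline.dataset()

# the resource yields plain SQL text instead of data items
@dlt.resource(file_format="sql")
def copied_table():
    query = dataset["example_table"].limit(5).query()
    yield f"CREATE OR REPLACE TABLE copied_table AS {query}"

# each yielded statement becomes a .sql load job executed on duckdb
pipeline.run(copied_table())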
56 changes: 55 additions & 1 deletion dlt/common/data_writers/writers.py
@@ -44,7 +44,7 @@
from dlt.common.libs.pyarrow import pyarrow as pa


TDataItemFormat = Literal["arrow", "object", "file"]
TDataItemFormat = Literal["arrow", "object", "file", "text"]
TWriter = TypeVar("TWriter", bound="DataWriter")


@@ -115,6 +115,8 @@ def item_format_from_file_extension(cls, extension: str) -> TDataItemFormat:
            return "object"
        elif extension == "parquet":
            return "arrow"
        elif extension == "sql":
            return "text"
        # those files may be imported by normalizer as is
        elif extension in LOADER_FILE_FORMATS:
            return "file"
@@ -175,6 +177,51 @@ def writer_spec(cls) -> FileWriterSpec:
        )


class SQLWriter(DataWriter):
    """Writes incoming items row by row into a text file and ensures a trailing ;"""

    def write_header(self, columns_schema: TTableSchemaColumns) -> None:
        # SQL jobs carry no column schema, so there is no header to write
        pass

    def write_data(self, items: Sequence[TDataItem]) -> None:
        super().write_data(items)
        self.items_count += len(items)

        for item in items:
            # items may arrive as plain SQL strings or as dicts wrapping the
            # statement under a "value" key; unwrap dicts before writing
            if isinstance(item, dict) and "value" in item:
                item = item["value"]
            item = item.strip()
            if not item.endswith(";"):
                item += ";"
            self._f.write(item + "\n")

    @classmethod
    def writer_spec(cls) -> FileWriterSpec:
        return FileWriterSpec(
            "sql",
            "text",
            file_extension="sql",
            is_binary_format=False,
            supports_schema_changes="True",
            supports_compression=False,
        )


class SQLItemWriter(SQLWriter):
    """Variant of SQLWriter registered for the "object" data item format."""

    @classmethod
    def writer_spec(cls) -> FileWriterSpec:
        return FileWriterSpec(
            "sql",
            "object",
            file_extension="sql",
            is_binary_format=False,
            supports_schema_changes="True",
            supports_compression=False,
        )


class TypedJsonlListWriter(JsonlWriter):
    def write_data(self, items: Sequence[TDataItem]) -> None:
        # skip JsonlWriter when calling super
@@ -670,6 +717,8 @@ def is_native_writer(writer_type: Type[DataWriter]) -> bool:
    ArrowToJsonlWriter,
    ArrowToTypedJsonlListWriter,
    ArrowToCsvWriter,
    SQLWriter,
    SQLItemWriter,
]

WRITER_SPECS: Dict[FileWriterSpec, Type[DataWriter]] = {
@@ -689,6 +738,11 @@ def is_native_writer(writer_type: Type[DataWriter]) -> bool:
        for writer in ALL_WRITERS
        if writer.writer_spec().data_item_format == "arrow" and is_native_writer(writer)
    ),
    "text": tuple(
        writer
        for writer in ALL_WRITERS
        if writer.writer_spec().data_item_format == "text" and is_native_writer(writer)
    ),
}


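To illustrate the behavior added above, here is a standalone sketch that mirrors the loop in SQLWriter.write_data (plain Python, not the dlt API; the dict form assumes items that wrap the statement under a "value" key):

# standalone sketch mirroring SQLWriter.write_data
statements = [
    "CREATE TABLE t AS SELECT 1 AS a",           # plain text item
    {"value": "  INSERT INTO t VALUES (2);  "},  # wrapped item
]

lines = []
for item in statements:
    if isinstance(item, dict) and "value" in item:
        item = item["value"]
    item = item.strip()
    if not item.endswith(";"):
        item += ";"
    lines.append(item + "\n")

print("".join(lines))
# CREATE TABLE t AS SELECT 1 AS a;
# INSERT INTO t VALUES (2);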
2 changes: 1 addition & 1 deletion dlt/common/destination/utils.py
@@ -152,7 +152,7 @@ def verify_supported_data_types(
    for parsed_file in new_jobs:
        formats = table_file_formats.setdefault(parsed_file.table_name, set())
        if parsed_file.file_format in LOADER_FILE_FORMATS:
            formats.add(parsed_file.file_format)  # type: ignore[arg-type]
            formats.add(parsed_file.file_format)
    # all file formats
    all_file_formats = set(capabilities.supported_loader_file_formats or []) | set(
        capabilities.supported_staging_file_formats or []
2 changes: 1 addition & 1 deletion dlt/common/schema/typing.py
@@ -126,7 +126,7 @@ class TColumnPropInfo(NamedTuple):
if prop in COLUMN_HINTS:
ColumnPropInfos[prop] = ColumnPropInfos[prop]._replace(is_hint=True)

TTableFormat = Literal["iceberg", "delta", "hive", "native"]
TTableFormat = Literal["iceberg", "delta", "hive", "native", "view"]
TFileFormat = Literal[Literal["preferred"], TLoaderFileFormat]
TTypeDetections = Literal[
"timestamp", "iso_timestamp", "iso_date", "large_integer", "hexbytes_to_text", "wei_to_double"
4 changes: 3 additions & 1 deletion dlt/common/typing.py
@@ -126,7 +126,9 @@ class SecretSentinel:
VARIANT_FIELD_FORMAT = "v_%s"
TFileOrPath = Union[str, PathLike, IO[Any]]
TSortOrder = Literal["asc", "desc"]
TLoaderFileFormat = Literal["jsonl", "typed-jsonl", "insert_values", "parquet", "csv", "reference"]
TLoaderFileFormat = Literal[
    "jsonl", "typed-jsonl", "insert_values", "parquet", "csv", "reference", "sql"
]
"""known loader file formats"""

TDynHintType = TypeVar("TDynHintType")
2 changes: 1 addition & 1 deletion dlt/destinations/impl/duckdb/factory.py
@@ -129,7 +129,7 @@ class duckdb(Destination[DuckDbClientConfiguration, "DuckDbClient"]):
    def _raw_capabilities(self) -> DestinationCapabilitiesContext:
        caps = DestinationCapabilitiesContext()
        caps.preferred_loader_file_format = "insert_values"
        caps.supported_loader_file_formats = ["insert_values", "parquet", "jsonl"]
        caps.supported_loader_file_formats = ["insert_values", "parquet", "jsonl", "sql"]
        caps.preferred_staging_file_format = None
        caps.supported_staging_file_formats = []
        caps.type_mapper = DuckDbTypeMapper
7 changes: 5 additions & 2 deletions dlt/extract/extract.py
@@ -45,7 +45,7 @@
from dlt.extract.reference import SourceReference
from dlt.extract.resource import DltResource
from dlt.extract.storage import ExtractStorage
from dlt.extract.extractors import ObjectExtractor, ArrowExtractor, Extractor
from dlt.extract.extractors import ObjectExtractor, ArrowExtractor, Extractor, TextExtractor
from dlt.extract.utils import get_data_item_format


@@ -343,6 +343,9 @@ def _extract_single_source(
            "arrow": ArrowExtractor(
                load_id, self.extract_storage.item_storages["arrow"], schema, collector=collector
            ),
            "text": TextExtractor(
                load_id, self.extract_storage.item_storages["text"], schema, collector=collector
            ),
        }
        # make sure we close storage on exception
        with collector(f"Extract {source.name}"):
@@ -363,7 +366,7 @@ def _extract_single_source(
collector.update("Resources", delta)
signals.raise_if_signalled()
resource = source.resources[pipe_item.pipe.name]
item_format = get_data_item_format(pipe_item.item)
item_format = get_data_item_format(pipe_item.item, resource.file_format)
extractors[item_format].write_items(
resource, pipe_item.item, pipe_item.meta
)
6 changes: 6 additions & 0 deletions dlt/extract/extractors.py
@@ -292,6 +292,12 @@ class ObjectExtractor(Extractor):
    pass


class TextExtractor(Extractor):
    """Extracts text items and writes them row by row into a text file"""

    pass


class ArrowExtractor(Extractor):
    """Extracts arrow data items into parquet. Normalizes arrow items column names.
    Compares the arrow schema to actual dlt table schema to reorder the columns and to
4 changes: 4 additions & 0 deletions dlt/extract/hints.py
@@ -178,6 +178,10 @@ def schema_contract(self) -> TTableHintTemplate[TSchemaContract]:
    def table_format(self) -> TTableHintTemplate[TTableFormat]:
        return None if self._hints is None else self._hints.get("table_format")

    @property
    def file_format(self) -> TTableHintTemplate[TFileFormat]:
        return None if self._hints is None else self._hints.get("file_format")

    @property
    def parent_table_name(self) -> TTableHintTemplate[str]:
        return None if self._hints is None else self._hints.get("parent")
3 changes: 3 additions & 0 deletions dlt/extract/storage.py
@@ -50,6 +50,9 @@ def __init__(self, config: NormalizeStorageConfiguration) -> None:
            "arrow": ExtractorItemStorage(
                self.new_packages, DataWriter.writer_spec_from_file_format("parquet", "arrow")
            ),
            "text": ExtractorItemStorage(
                self.new_packages, DataWriter.writer_spec_from_file_format("sql", "text")
            ),
        }

    def create_load_package(self, schema: Schema, reuse_exiting_package: bool = True) -> str:
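For reference, the "sql"/"text" pair registered here resolves to the writer spec declared by SQLWriter earlier in this PR; a small sketch of that lookup:

from dlt.common.data_writers.writers import DataWriter

# resolves to the spec declared by SQLWriter.writer_spec() above:
# file format "sql", data item format "text", file extension "sql", no compression
spec = DataWriter.writer_spec_from_file_format("sql", "text")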
12 changes: 11 additions & 1 deletion dlt/extract/utils.py
@@ -44,6 +44,8 @@
    SupportsPipe,
)

from dlt.common.schema.typing import TFileFormat

try:
    from dlt.common.libs import pydantic
except MissingDependencyException:
@@ -61,14 +63,22 @@
    pandas = None


def get_data_item_format(items: TDataItems) -> TDataItemFormat:
def get_data_item_format(
    items: TDataItems, file_format: TTableHintTemplate[TFileFormat] = None
) -> TDataItemFormat:
    """Detect the format of the data item from `items`.

    Reverts to `object` for empty lists

    Returns:
        The data file format.
    """

    # if file format is specified as sql, we expect pure text from the resource
    file_format = file_format(items) if callable(file_format) else file_format
    if file_format == "sql":
        return "text"

    if not pyarrow and not pandas:
        return "object"

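With the change above, a file_format hint on the resource short-circuits detection; a quick sketch of the expected routing:

from dlt.extract.utils import get_data_item_format

# a "sql" file format hint sends items to the "text" extractor regardless of their type
assert get_data_item_format("SELECT 1", file_format="sql") == "text"
# without the hint, plain Python objects still resolve to "object"
assert get_data_item_format([{"a": 1}]) == "object"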
2 changes: 1 addition & 1 deletion dlt/normalize/worker.py
@@ -111,7 +111,7 @@ def _get_items_normalizer(
if item_format == "file":
# if we want to import file, create a spec that may be used only for importing
best_writer_spec = create_import_spec(
parsed_file_name.file_format, items_supported_file_formats # type: ignore[arg-type]
parsed_file_name.file_format, items_supported_file_formats
)

config_loader_file_format = config.loader_file_format
41 changes: 41 additions & 0 deletions tests/load/test_sql_resource.py
@@ -0,0 +1,41 @@
# test the sql insert job loader; works only on duckdb for now

from typing import Any

import dlt

from dlt.common.destination.dataset import SupportsReadableDataset

from tests.pipeline.utils import load_table_counts


def test_sql_job() -> None:
    # populate a table with 10 items and retrieve the dataset
    pipeline = dlt.pipeline(
        pipeline_name="example_pipeline", destination="duckdb", dataset_name="example_dataset"
    )
    pipeline.run([{"a": i} for i in range(10)], table_name="example_table")
    dataset = pipeline.dataset()

    # create a resource that generates sql statements to create 2 new tables
    @dlt.resource(file_format="sql")
    def copied_table() -> Any:
        query = dataset["example_table"].limit(5).query()
        yield f"CREATE OR REPLACE TABLE copied_table AS {query}"
        query = dataset["example_table"].limit(7).query()
        yield f"CREATE OR REPLACE TABLE copied_table2 AS {query}"

    # run sql jobs
    pipeline.run(copied_table())

    # the two tables were created
    assert load_table_counts(pipeline, "example_table", "copied_table", "copied_table2") == {
        "example_table": 10,
        "copied_table": 5,
        "copied_table2": 7,
    }

    # we have a table entry for the main table "copied_table"
    assert "copied_table" in pipeline.default_schema.tables
    # but no columns; it's up to the user to provide a schema
    assert len(pipeline.default_schema.tables["copied_table"]["columns"]) == 0