add support for text item type and text to sql job writer
sh-rp committed Mar 3, 2025
1 parent c90bf91 commit 5aab2d2
Showing 12 changed files with 111 additions and 9 deletions.
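In user-facing terms, a resource can now declare `file_format="sql"` and yield raw SQL text; each yielded statement is written as one line of a `.sql` job file and executed on the destination (only duckdb is whitelisted in this commit). A minimal sketch modeled on the test added at the bottom of this commit; the pipeline, resource, and statement names are illustrative:

import dlt


@dlt.resource(file_format="sql")
def maintenance_statements():
    # plain SQL strings; unqualified table names resolve the same way they do in the test below
    yield "CREATE TABLE IF NOT EXISTS audit_log (event TEXT);"
    yield "INSERT INTO audit_log VALUES ('pipeline ran');"


pipeline = dlt.pipeline(pipeline_name="sql_job_demo", destination="duckdb")
pipeline.run(maintenance_statements())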
35 changes: 34 additions & 1 deletion dlt/common/data_writers/writers.py
@@ -44,7 +44,7 @@
from dlt.common.libs.pyarrow import pyarrow as pa


TDataItemFormat = Literal["arrow", "object", "file"]
TDataItemFormat = Literal["arrow", "object", "file", "text"]
TWriter = TypeVar("TWriter", bound="DataWriter")


@@ -115,6 +115,8 @@ def item_format_from_file_extension(cls, extension: str) -> TDataItemFormat:
            return "object"
        elif extension == "parquet":
            return "arrow"
        elif extension == "sql":
            return "text"
        # those files may be imported by normalizer as is
        elif extension in LOADER_FILE_FORMATS:
            return "file"
@@ -175,6 +177,31 @@ def writer_spec(cls) -> FileWriterSpec:
        )


class TextWriter(DataWriter):
    """Writes incoming items row by row into a text file"""

    def write_header(self, columns_schema: TTableSchemaColumns) -> None:
        pass

    def write_data(self, items: Sequence[TDataItem]) -> None:
        super().write_data(items)
        self.items_count += len(items)

        for item in items:
            self._f.write(item + "\n")

    @classmethod
    def writer_spec(cls) -> FileWriterSpec:
        return FileWriterSpec(
            "sql",
            "text",
            file_extension="sql",
            is_binary_format=False,
            supports_schema_changes="True",
            supports_compression=False,
        )


class TypedJsonlListWriter(JsonlWriter):
    def write_data(self, items: Sequence[TDataItem]) -> None:
        # skip JsonlWriter when calling super
@@ -670,6 +697,7 @@ def is_native_writer(writer_type: Type[DataWriter]) -> bool:
    ArrowToJsonlWriter,
    ArrowToTypedJsonlListWriter,
    ArrowToCsvWriter,
    TextWriter,
]

WRITER_SPECS: Dict[FileWriterSpec, Type[DataWriter]] = {
@@ -689,6 +717,11 @@ def is_native_writer(writer_type: Type[DataWriter]) -> bool:
        for writer in ALL_WRITERS
        if writer.writer_spec().data_item_format == "arrow" and is_native_writer(writer)
    ),
    "text": tuple(
        writer
        for writer in ALL_WRITERS
        if writer.writer_spec().data_item_format == "text" and is_native_writer(writer)
    ),
}


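The new writer keeps no per-column state: `write_header` is a no-op and `write_data` simply appends each item followed by a newline, so a resource yielding two statements produces a two-line `.sql` job file. A standalone sketch of that write path (the helper below is illustrative, not dlt's API):

from typing import Sequence, TextIO


def write_text_items(f: TextIO, items: Sequence[str]) -> int:
    # mirror TextWriter.write_data above: one item per line, return the count of items written
    for item in items:
        f.write(item + "\n")
    return len(items)


with open("job_example.sql", "w", encoding="utf-8") as job_file:
    written = write_text_items(job_file, ["CREATE TABLE t (a INT);", "INSERT INTO t VALUES (1);"])
assert written == 2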
2 changes: 1 addition & 1 deletion dlt/common/destination/utils.py
@@ -152,7 +152,7 @@ def verify_supported_data_types(
    for parsed_file in new_jobs:
        formats = table_file_formats.setdefault(parsed_file.table_name, set())
        if parsed_file.file_format in LOADER_FILE_FORMATS:
            formats.add(parsed_file.file_format) # type: ignore[arg-type]
            formats.add(parsed_file.file_format)
    # all file formats
    all_file_formats = set(capabilities.supported_loader_file_formats or []) | set(
        capabilities.supported_staging_file_formats or []
2 changes: 1 addition & 1 deletion dlt/common/schema/typing.py
@@ -126,7 +126,7 @@ class TColumnPropInfo(NamedTuple):
if prop in COLUMN_HINTS:
    ColumnPropInfos[prop] = ColumnPropInfos[prop]._replace(is_hint=True)

TTableFormat = Literal["iceberg", "delta", "hive", "native"]
TTableFormat = Literal["iceberg", "delta", "hive", "native", "view"]
TFileFormat = Literal[Literal["preferred"], TLoaderFileFormat]
TTypeDetections = Literal[
    "timestamp", "iso_timestamp", "iso_date", "large_integer", "hexbytes_to_text", "wei_to_double"
4 changes: 3 additions & 1 deletion dlt/common/typing.py
@@ -126,7 +126,9 @@ class SecretSentinel:
VARIANT_FIELD_FORMAT = "v_%s"
TFileOrPath = Union[str, PathLike, IO[Any]]
TSortOrder = Literal["asc", "desc"]
TLoaderFileFormat = Literal["jsonl", "typed-jsonl", "insert_values", "parquet", "csv", "reference"]
TLoaderFileFormat = Literal[
    "jsonl", "typed-jsonl", "insert_values", "parquet", "csv", "reference", "sql"
]
"""known loader file formats"""

TDynHintType = TypeVar("TDynHintType")
2 changes: 1 addition & 1 deletion dlt/destinations/impl/duckdb/factory.py
@@ -129,7 +129,7 @@ class duckdb(Destination[DuckDbClientConfiguration, "DuckDbClient"]):
    def _raw_capabilities(self) -> DestinationCapabilitiesContext:
        caps = DestinationCapabilitiesContext()
        caps.preferred_loader_file_format = "insert_values"
        caps.supported_loader_file_formats = ["insert_values", "parquet", "jsonl"]
        caps.supported_loader_file_formats = ["insert_values", "parquet", "jsonl", "sql"]
        caps.preferred_staging_file_format = None
        caps.supported_staging_file_formats = []
        caps.type_mapper = DuckDbTypeMapper
7 changes: 5 additions & 2 deletions dlt/extract/extract.py
@@ -45,7 +45,7 @@
from dlt.extract.reference import SourceReference
from dlt.extract.resource import DltResource
from dlt.extract.storage import ExtractStorage
from dlt.extract.extractors import ObjectExtractor, ArrowExtractor, Extractor
from dlt.extract.extractors import ObjectExtractor, ArrowExtractor, Extractor, TextExtractor
from dlt.extract.utils import get_data_item_format


@@ -343,6 +343,9 @@ def _extract_single_source(
            "arrow": ArrowExtractor(
                load_id, self.extract_storage.item_storages["arrow"], schema, collector=collector
            ),
            "text": TextExtractor(
                load_id, self.extract_storage.item_storages["text"], schema, collector=collector
            ),
        }
        # make sure we close storage on exception
        with collector(f"Extract {source.name}"):
@@ -363,7 +366,7 @@ def _extract_single_source(
                        collector.update("Resources", delta)
                    signals.raise_if_signalled()
                    resource = source.resources[pipe_item.pipe.name]
                    item_format = get_data_item_format(pipe_item.item)
                    item_format = get_data_item_format(pipe_item.item, resource.file_format)
                    extractors[item_format].write_items(
                        resource, pipe_item.item, pipe_item.meta
                    )
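The extractor registry now mirrors the writer side: each data item format maps to an extractor instance, and the resource's `file_format` hint is passed into format detection so that string items from an `"sql"` resource reach the text extractor instead of the object one. A simplified sketch of that dispatch (names and types here are illustrative, not dlt's):

from typing import Any, Dict, Optional


def pick_extractor(item: Any, file_format: Optional[str], extractors: Dict[str, Any]) -> Any:
    # the resource-level hint takes precedence over type-based detection;
    # arrow/pandas detection is omitted in this sketch
    fmt = "text" if file_format == "sql" else "object"
    return extractors[fmt]


registry = {"object": "ObjectExtractor", "text": "TextExtractor"}
assert pick_extractor("SELECT 1;", "sql", registry) == "TextExtractor"
assert pick_extractor({"a": 1}, None, registry) == "ObjectExtractor"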
6 changes: 6 additions & 0 deletions dlt/extract/extractors.py
@@ -292,6 +292,12 @@ class ObjectExtractor(Extractor):
    pass


class TextExtractor(Extractor):
    """Extracts text items and writes them row by row into a text file"""

    pass


class ArrowExtractor(Extractor):
    """Extracts arrow data items into parquet. Normalizes arrow items column names.
    Compares the arrow schema to actual dlt table schema to reorder the columns and to
4 changes: 4 additions & 0 deletions dlt/extract/hints.py
@@ -178,6 +178,10 @@ def schema_contract(self) -> TTableHintTemplate[TSchemaContract]:
    def table_format(self) -> TTableHintTemplate[TTableFormat]:
        return None if self._hints is None else self._hints.get("table_format")

    @property
    def file_format(self) -> TTableHintTemplate[TFileFormat]:
        return None if self._hints is None else self._hints.get("file_format")

    @property
    def parent_table_name(self) -> TTableHintTemplate[str]:
        return None if self._hints is None else self._hints.get("parent")
3 changes: 3 additions & 0 deletions dlt/extract/storage.py
@@ -50,6 +50,9 @@ def __init__(self, config: NormalizeStorageConfiguration) -> None:
            "arrow": ExtractorItemStorage(
                self.new_packages, DataWriter.writer_spec_from_file_format("parquet", "arrow")
            ),
            "text": ExtractorItemStorage(
                self.new_packages, DataWriter.writer_spec_from_file_format("sql", "text")
            ),
        }

    def create_load_package(self, schema: Schema, reuse_exiting_package: bool = True) -> str:
12 changes: 11 additions & 1 deletion dlt/extract/utils.py
Expand Up @@ -44,6 +44,8 @@
SupportsPipe,
)

from dlt.common.schema.typing import TFileFormat

try:
    from dlt.common.libs import pydantic
except MissingDependencyException:
@@ -61,14 +63,22 @@
    pandas = None


def get_data_item_format(items: TDataItems) -> TDataItemFormat:
def get_data_item_format(
    items: TDataItems, file_format: TTableHintTemplate[TFileFormat] = None
) -> TDataItemFormat:
    """Detect the format of the data item from `items`.
    Reverts to `object` for empty lists
    Returns:
        The data file format.
    """

    # if file format is specified as sql, we expect pure text from the resource
    file_format = file_format(items) if callable(file_format) else file_format
    if file_format == "sql":
        return "text"

    if not pyarrow and not pandas:
        return "object"

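`file_format` arrives here as a table-hint template, that is, either a plain value or a callable evaluated against the items, which is why it is resolved before the `"sql"` check. A small illustration of that resolution pattern (the helper name is ours, not dlt's):

from typing import Any, Callable, Optional, Union

THint = Union[str, Callable[[Any], str], None]


def resolve_hint(hint: THint, items: Any) -> Optional[str]:
    # same pattern as the diff above: call the hint with the items if it is callable
    return hint(items) if callable(hint) else hint


assert resolve_hint("sql", None) == "sql"
assert resolve_hint(lambda items: "sql" if isinstance(items, str) else "preferred", "SELECT 1;") == "sql"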
2 changes: 1 addition & 1 deletion dlt/normalize/worker.py
@@ -111,7 +111,7 @@ def _get_items_normalizer(
    if item_format == "file":
        # if we want to import file, create a spec that may be used only for importing
        best_writer_spec = create_import_spec(
            parsed_file_name.file_format, items_supported_file_formats # type: ignore[arg-type]
            parsed_file_name.file_format, items_supported_file_formats
        )

    config_loader_file_format = config.loader_file_format
41 changes: 41 additions & 0 deletions tests/load/test_sql_resource.py
@@ -0,0 +1,41 @@
# test the sql insert job loader, works only on duckdb for now

from typing import Any

import dlt

from dlt.common.destination.dataset import SupportsReadableDataset

from tests.pipeline.utils import load_table_counts


def test_sql_job() -> None:
    # populate a table with 10 items and retrieve dataset
    pipeline = dlt.pipeline(
        pipeline_name="example_pipeline", destination="duckdb", dataset_name="example_dataset"
    )
    pipeline.run([{"a": i} for i in range(10)], table_name="example_table")
    dataset = pipeline.dataset()

    # create a resource that generates sql statements to create 2 new tables
    @dlt.resource(file_format="sql")
    def copied_table() -> Any:
        query = dataset["example_table"].limit(5).query()
        yield f"CREATE OR REPLACE TABLE copied_table AS {query};"
        query = dataset["example_table"].limit(7).query()
        yield f"CREATE OR REPLACE TABLE copied_table2 AS {query};"

    # run sql jobs
    pipeline.run(copied_table())

    # the two tables were created
    assert load_table_counts(pipeline, "example_table", "copied_table", "copied_table2") == {
        "example_table": 10,
        "copied_table": 5,
        "copied_table2": 7,
    }

    # we have a table entry for the main table "copied_table"
    assert "copied_table" in pipeline.default_schema.tables
    # but no columns, it's up to the user to provide a schema
    assert len(pipeline.default_schema.tables["copied_table"]["columns"]) == 0
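Since no columns are inferred for the sql-backed table, a user who wants column metadata would have to supply it explicitly. A hedged sketch using dlt's standard `columns` hint; whether the sql job path picks these hints up is an assumption on our part, not something this commit's test verifies:

from typing import Any

import dlt


@dlt.resource(
    file_format="sql",
    columns={"a": {"data_type": "bigint"}},  # assumed to populate the schema entry; not tested in this commit
)
def copied_table_with_schema() -> Any:
    # illustrative statement; table and column names are ours
    yield "CREATE OR REPLACE TABLE copied_table_with_schema AS SELECT 1 AS a;"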
