From 343833fea40ad73c7a8894bbbb990ab5edd53f04 Mon Sep 17 00:00:00 2001
From: dave
Date: Mon, 3 Mar 2025 16:46:39 +0100
Subject: [PATCH] map the "sql" file extension to the "text" item format and
 drop the unused SqlWriter classes

---
 dlt/common/data_writers/writers.py | 27 +++++++--------------------
 tests/load/test_sql_resource.py    |  6 ++++++
 2 files changed, 13 insertions(+), 20 deletions(-)

diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py
index 325fa802b7..e7b318ae48 100644
--- a/dlt/common/data_writers/writers.py
+++ b/dlt/common/data_writers/writers.py
@@ -115,6 +115,8 @@ def item_format_from_file_extension(cls, extension: str) -> TDataItemFormat:
             return "object"
         elif extension == "parquet":
             return "arrow"
+        elif extension == "sql":
+            return "text"
         # those files may be imported by normalizer as is
         elif extension in LOADER_FILE_FORMATS:
             return "file"
@@ -182,8 +184,6 @@ def write_header(self, columns_schema: TTableSchemaColumns) -> None:
         pass
 
     def write_data(self, items: Sequence[TDataItem]) -> None:
-        # NOTE: is this too hacky? We take the first item and the value of the first item
-        # and interpret this as the sql query
         super().write_data(items)
         self.items_count += len(items)
 
@@ -202,23 +202,6 @@ def writer_spec(cls) -> FileWriterSpec:
         )
 
 
-class SqlWriter(DataWriter):
-    pass
-
-
-class SqlToTextWriter(SqlWriter, TextWriter):
-    @classmethod
-    def writer_spec(cls) -> FileWriterSpec:
-        return FileWriterSpec(
-            "sql",
-            "object",
-            file_extension="sql",
-            is_binary_format=False,
-            supports_schema_changes="True",
-            supports_compression=False,
-        )
-
-
 class TypedJsonlListWriter(JsonlWriter):
     def write_data(self, items: Sequence[TDataItem]) -> None:
         # skip JsonlWriter when calling super
@@ -715,7 +698,6 @@ def is_native_writer(writer_type: Type[DataWriter]) -> bool:
     ArrowToTypedJsonlListWriter,
     ArrowToCsvWriter,
     TextWriter,
-    SqlToTextWriter,
 ]
 
 WRITER_SPECS: Dict[FileWriterSpec, Type[DataWriter]] = {
@@ -735,6 +717,11 @@ def is_native_writer(writer_type: Type[DataWriter]) -> bool:
         for writer in ALL_WRITERS
         if writer.writer_spec().data_item_format == "arrow" and is_native_writer(writer)
     ),
+    "text": tuple(
+        writer
+        for writer in ALL_WRITERS
+        if writer.writer_spec().data_item_format == "text" and is_native_writer(writer)
+    ),
 }
 
 
diff --git a/tests/load/test_sql_resource.py b/tests/load/test_sql_resource.py
index 9a95017e0f..c015aa0c12 100644
--- a/tests/load/test_sql_resource.py
+++ b/tests/load/test_sql_resource.py
@@ -28,8 +28,14 @@ def copied_table() -> Any:
     # run sql jobs
     pipeline.run(copied_table())
 
+    # the two copied tables were created
     assert load_table_counts(pipeline, "example_table", "copied_table", "copied_table2") == {
         "example_table": 10,
         "copied_table": 5,
         "copied_table2": 7,
     }
+
+    # we have a table entry for the main table "copied_table"
+    assert "copied_table" in pipeline.default_schema.tables
+    # but no columns: it is up to the user to provide a schema
+    assert len(pipeline.default_schema.tables["copied_table"]["columns"]) == 0
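
Not part of the patch: a minimal sketch of what the writers.py hunks wire
together, runnable once the patch is applied. It uses only names that appear
in the diff above; the one assumption, implied by the new registry entry, is
that TextWriter.writer_spec() declares the "text" item format.

    from dlt.common.data_writers.writers import (
        ALL_WRITERS,
        DataWriter,
        TextWriter,
        is_native_writer,
    )

    # first writers.py hunk: ".sql" files now classify as the "text" item format
    assert DataWriter.item_format_from_file_extension("sql") == "text"

    # rebuild the new "text" registry entry exactly as the last hunk does;
    # assumption: TextWriter is the native writer whose spec declares "text"
    text_writers = tuple(
        writer
        for writer in ALL_WRITERS
        if writer.writer_spec().data_item_format == "text" and is_native_writer(writer)
    )
    assert TextWriter in text_writers

With this wiring, .sql files flow through the plain TextWriter rather than the
removed SqlToTextWriter subclass, whose spec declared the "object" item format
and a string-valued supports_schema_changes.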