diff --git a/dlt/common/libs/pyarrow.py b/dlt/common/libs/pyarrow.py
index a020dbbf8f..88c38d915a 100644
--- a/dlt/common/libs/pyarrow.py
+++ b/dlt/common/libs/pyarrow.py
@@ -4,7 +4,6 @@
     Any,
     Dict,
     Mapping,
-    Tuple,
     Optional,
     Union,
     Callable,
@@ -670,9 +669,16 @@ def row_tuples_to_arrow(
 
     arrow_schema = columns_to_arrow(columns, caps, tz)
 
+    def infer_first_non_null_type(idx: int) -> Any:  # first non-None value's type, else NoneType
+        for row in rows:
+            value = row[idx]
+            if value is not None:
+                return type(value)
+        return type(None)  # every row holds None in this column, or there are no rows
+
     for idx in range(0, len(arrow_schema.names)):
         field = arrow_schema.field(idx)
-        py_type = type(rows[0][idx])
+        py_type = infer_first_non_null_type(idx)  # a None in the first row no longer decides the type
         # cast double / float ndarrays to decimals if type mismatch, looks like decimals and floats are often mixed up in dialects
         if pa.types.is_decimal(field.type) and issubclass(py_type, (str, float)):
             logger.warning(
diff --git a/tests/sources/sql_database/test_arrow_helpers.py b/tests/sources/sql_database/test_arrow_helpers.py
index abd063889c..1ff233b497 100644
--- a/tests/sources/sql_database/test_arrow_helpers.py
+++ b/tests/sources/sql_database/test_arrow_helpers.py
@@ -83,6 +83,39 @@ def test_row_tuples_to_arrow_unknown_types(all_unknown: bool) -> None:
     assert pa.types.is_list(result[7].type)
 
 
+@pytest.mark.parametrize("all_unknown", [True, False])
+def test_row_tuples_to_arrow_detects_nullable_json(all_unknown: bool) -> None:
+    rows = [
+        (1, None),  # json_col is None in the first row
+        (2, {"ix": 2}),  # first non-None json value appears in the second row
+    ]
+
+    columns = {
+        "int_col": {"name": "int_col", "data_type": "bigint", "nullable": False},
+        "json_col": {"name": "json_col", "data_type": "json", "nullable": False},
+    }
+
+    if all_unknown:
+        for col in columns.values():
+            col.pop("data_type", None)  # strip the data_type hint from every column
+
+    result = row_tuples_to_arrow(rows, columns=columns, tz="UTC")  # type: ignore
+
+    # The result is an Arrow table holding every column, in the original order
+    assert result.num_columns == len(columns)
+    result_col_names = [f.name for f in result.schema]
+    expected_names = list(columns)
+    assert result_col_names == expected_names
+
+    assert pa.types.is_int64(result[0].type)
+
+    # FIXME: why isn't json_col always coerced to a string? With the data_type hints removed it comes back as a struct.
+    if all_unknown:
+        assert pa.types.is_struct(result[1].type)
+    else:
+        assert pa.types.is_string(result[1].type)
+
+
 pytest.importorskip("sqlalchemy", minversion="2.0")
 
 