diff --git a/cpp/src/arrow/csv/options.cc b/cpp/src/arrow/csv/options.cc
index 365b5646b66..52daa9c5fc6 100644
--- a/cpp/src/arrow/csv/options.cc
+++ b/cpp/src/arrow/csv/options.cc
@@ -43,6 +43,7 @@ ConvertOptions ConvertOptions::Defaults() {
                          "NULL", "NaN", "n/a", "nan", "null"};
   options.true_values = {"1", "True", "TRUE", "true"};
   options.false_values = {"0", "False", "FALSE", "false"};
+  options.default_column_type = nullptr;
   return options;
 }
diff --git a/cpp/src/arrow/csv/options.h b/cpp/src/arrow/csv/options.h
index 10e55bf838c..839550c3f0c 100644
--- a/cpp/src/arrow/csv/options.h
+++ b/cpp/src/arrow/csv/options.h
@@ -76,6 +76,8 @@ struct ARROW_EXPORT ConvertOptions {
   bool check_utf8 = true;
   /// Optional per-column types (disabling type inference on those columns)
   std::unordered_map<std::string, std::shared_ptr<DataType>> column_types;
+  /// Default type for columns not in `column_types` (disabling type inference)
+  std::shared_ptr<DataType> default_column_type;
   /// Recognized spellings for null values
   std::vector<std::string> null_values;
   /// Recognized spellings for boolean true values
diff --git a/cpp/src/arrow/csv/reader.cc b/cpp/src/arrow/csv/reader.cc
index 3c4e7e3da0c..4767626ae6c 100644
--- a/cpp/src/arrow/csv/reader.cc
+++ b/cpp/src/arrow/csv/reader.cc
@@ -674,8 +674,15 @@ class ReaderMixin {
       // Does the named column have a fixed type?
       auto it = convert_options_.column_types.find(col_name);
       if (it == convert_options_.column_types.end()) {
-        conversion_schema_.columns.push_back(
-            ConversionSchema::InferredColumn(std::move(col_name), col_index));
+        // If not explicitly typed, respect default_column_type when provided
+        if (convert_options_.default_column_type != nullptr) {
+          conversion_schema_.columns.push_back(ConversionSchema::TypedColumn(
+              std::move(col_name), col_index, convert_options_.default_column_type));
+        } else {
+          conversion_schema_.columns.push_back(
+              ConversionSchema::InferredColumn(std::move(col_name), col_index));
+        }
       } else {
         conversion_schema_.columns.push_back(
             ConversionSchema::TypedColumn(std::move(col_name), col_index, it->second));
diff --git a/cpp/src/arrow/csv/reader_test.cc b/cpp/src/arrow/csv/reader_test.cc
index 57cc7d8efa5..4035bf88b29 100644
--- a/cpp/src/arrow/csv/reader_test.cc
+++ b/cpp/src/arrow/csv/reader_test.cc
@@ -488,5 +488,138 @@ TEST(CountRowsAsync, Errors) {
                                            internal::GetCpuThreadPool(), read_options, parse_options));
 }
 
+TEST(ReaderTests, DefaultColumnTypePartialDefault) {
+  // Input with header; all unspecified columns default to string,
+  // while `id` is explicitly overridden to int64
+  auto table_buffer = std::make_shared<Buffer>(
+      "id,name,value,date\n"
+      "0000101,apple,0003.1400,2024-01-15\n"
+      "00102,banana,001.6180,2024-02-20\n"
+      "0003,cherry,02.71800,2024-03-25\n");
+
+  auto input = std::make_shared<io::BufferReader>(table_buffer);
+  auto read_options = ReadOptions::Defaults();
+  auto parse_options = ParseOptions::Defaults();
+  auto convert_options = ConvertOptions::Defaults();
+  convert_options.column_types["id"] = int64();
+  convert_options.default_column_type = utf8();
+
+  ASSERT_OK_AND_ASSIGN(
+      auto reader,
+      TableReader::Make(io::default_io_context(), input, read_options, parse_options,
+                        convert_options));
+  ASSERT_OK_AND_ASSIGN(auto table, reader->Read());
+
+  auto expected_schema = schema({field("id", int64()), field("name", utf8()),
+                                 field("value", utf8()), field("date", utf8())});
+  AssertSchemaEqual(expected_schema, table->schema());
+
+  auto expected_table = TableFromJSON(
+      expected_schema,
+      {R"([{"id":101, "name":"apple", "value":"0003.1400", "date":"2024-01-15"},
+           {"id":102, "name":"banana", "value":"001.6180", "date":"2024-02-20"},
+           {"id":3, "name":"cherry", "value":"02.71800", "date":"2024-03-25"}])"});
+  ASSERT_TRUE(table->Equals(*expected_table));
+}
+
+TEST(ReaderTests, DefaultColumnTypeAllStringsWithHeader) {
+  // Input with header; default all columns to strings
+  auto table_buffer = std::make_shared<Buffer>(
+      "Record_Type|ID|Code|Quantity_1|Quantity_2|Amount_1|Amount_2|Amount_3|"
+      "Flag|Note|Total_Amount\n"
+      "AB|000388907|abc|0150|012|000045.67|000000.10|000001.25|Y|noteA|000045.6700\n");
+
+  auto input = std::make_shared<io::BufferReader>(table_buffer);
+  auto read_options = ReadOptions::Defaults();
+  auto parse_options = ParseOptions::Defaults();
+  parse_options.delimiter = '|';
+  auto convert_options = ConvertOptions::Defaults();
+  convert_options.default_column_type = utf8();
+
+  ASSERT_OK_AND_ASSIGN(
+      auto reader,
+      TableReader::Make(io::default_io_context(), input, read_options, parse_options,
+                        convert_options));
+  ASSERT_OK_AND_ASSIGN(auto table, reader->Read());
+
+  auto expected_schema = schema({field("Record_Type", utf8()),
+                                 field("ID", utf8()),
+                                 field("Code", utf8()),
+                                 field("Quantity_1", utf8()),
+                                 field("Quantity_2", utf8()),
+                                 field("Amount_1", utf8()),
+                                 field("Amount_2", utf8()),
+                                 field("Amount_3", utf8()),
+                                 field("Flag", utf8()),
+                                 field("Note", utf8()),
+                                 field("Total_Amount", utf8())});
+  AssertSchemaEqual(expected_schema, table->schema());
+
+  auto expected_table = TableFromJSON(
+      expected_schema,
+      {R"([{
+        "Record_Type":"AB",
+        "ID":"000388907",
+        "Code":"abc",
+        "Quantity_1":"0150",
+        "Quantity_2":"012",
+        "Amount_1":"000045.67",
+        "Amount_2":"000000.10",
+        "Amount_3":"000001.25",
+        "Flag":"Y",
+        "Note":"noteA",
+        "Total_Amount":"000045.6700"
+      }])"});
+  ASSERT_TRUE(table->Equals(*expected_table));
+}
+
+TEST(ReaderTests, DefaultColumnTypeAllStringsNoHeader) {
+  // Input without header; autogenerate column names and default all to strings
+  auto table_buffer = std::make_shared<Buffer>(
+      "AB|000388907|abc|0150|012|000045.67|000000.10|000001.25|Y|noteA|000045.6700\n");
+
+  auto input = std::make_shared<io::BufferReader>(table_buffer);
+  auto read_options = ReadOptions::Defaults();
+  read_options.autogenerate_column_names = true;  // treat first row as data
+  auto parse_options = ParseOptions::Defaults();
+  parse_options.delimiter = '|';
+  auto convert_options = ConvertOptions::Defaults();
+  convert_options.default_column_type = utf8();
+
+  ASSERT_OK_AND_ASSIGN(
+      auto reader,
+      TableReader::Make(io::default_io_context(), input, read_options, parse_options,
+                        convert_options));
+  ASSERT_OK_AND_ASSIGN(auto table, reader->Read());
+
+  auto expected_schema = schema({field("f0", utf8()), field("f1", utf8()),
+                                 field("f2", utf8()), field("f3", utf8()),
+                                 field("f4", utf8()), field("f5", utf8()),
+                                 field("f6", utf8()), field("f7", utf8()),
+                                 field("f8", utf8()), field("f9", utf8()),
+                                 field("f10", utf8())});
+  AssertSchemaEqual(expected_schema, table->schema());
+
+  auto expected_table = TableFromJSON(
+      expected_schema,
+      {R"([{
+        "f0":"AB",
+        "f1":"000388907",
+        "f2":"abc",
+        "f3":"0150",
+        "f4":"012",
+        "f5":"000045.67",
+        "f6":"000000.10",
+        "f7":"000001.25",
+        "f8":"Y",
+        "f9":"noteA",
+        "f10":"000045.6700"
+      }])"});
+  ASSERT_TRUE(table->Equals(*expected_table));
+}
+
 }  // namespace csv
 }  // namespace arrow
diff --git a/docs/source/python/csv.rst b/docs/source/python/csv.rst
index 5eb68e9ccdc..27b740cdfd7 100644
--- a/docs/source/python/csv.rst
+++ b/docs/source/python/csv.rst
@@ -136,6 +136,7 @@ Available convert options are:
 
    ~ConvertOptions.check_utf8
    ~ConvertOptions.column_types
+   ~ConvertOptions.default_column_type
    ~ConvertOptions.null_values
    ~ConvertOptions.true_values
    ~ConvertOptions.false_values
diff --git a/python/pyarrow/_csv.pyx b/python/pyarrow/_csv.pyx
index ed9d20beb6b..ef84078f134 100644
--- a/python/pyarrow/_csv.pyx
+++ b/python/pyarrow/_csv.pyx
@@ -613,6 +613,9 @@ cdef class ConvertOptions(_Weakrefable):
     column_types : pyarrow.Schema or dict, optional
         Explicitly map column names to column types. Passing this argument
         disables type inference on the defined columns.
+    default_column_type : pyarrow.DataType, optional
+        Default type for columns not listed in column_types. Passing this
+        argument disables type inference on those columns.
     null_values : list, optional
         A sequence of strings that denote nulls in the data (defaults are
         appropriate in most cases). Note that by default,
@@ -807,6 +810,59 @@ cdef class ConvertOptions(_Weakrefable):
     fast: bool
     ----
    fast: [[true,true,false,false,null]]
+
+    Set a default column type for all columns (disables type inference):
+
+    >>> convert_options = csv.ConvertOptions(default_column_type=pa.string())
+    >>> csv.read_csv(io.BytesIO(s.encode()), convert_options=convert_options)
+    pyarrow.Table
+    animals: string
+    n_legs: string
+    entry: string
+    fast: string
+    ----
+    animals: [["Flamingo","Horse","Brittle stars","Centipede",""]]
+    n_legs: [["2","4","5","100","6"]]
+    entry: [["01/03/2022","02/03/2022","03/03/2022","04/03/2022","05/03/2022"]]
+    fast: [["Yes","Yes","No","No",""]]
+
+    Combine default_column_type with column_types; types given in
+    column_types override the default:
+
+    >>> convert_options = csv.ConvertOptions(
+    ...     column_types={"n_legs": pa.int64(), "fast": pa.bool_()},
+    ...     default_column_type=pa.string(),
+    ...     true_values=["Yes"],
+    ...     false_values=["No"])
+    >>> csv.read_csv(io.BytesIO(s.encode()), convert_options=convert_options)
+    pyarrow.Table
+    animals: string
+    n_legs: int64
+    entry: string
+    fast: bool
+    ----
+    animals: [["Flamingo","Horse","Brittle stars","Centipede",""]]
+    n_legs: [[2,4,5,100,6]]
+    entry: [["01/03/2022","02/03/2022","03/03/2022","04/03/2022","05/03/2022"]]
+    fast: [[true,true,false,false,null]]
+
+    Use default_column_type together with selective column_types for
+    mixed type conversion:
+
+    >>> convert_options = csv.ConvertOptions(
+    ...     column_types={"animals": pa.string(),
+    ...                   "entry": pa.timestamp('s')},
+    ...     default_column_type=pa.string(),
+    ...     timestamp_parsers=["%m/%d/%Y"])
+    >>> csv.read_csv(io.BytesIO(s.encode()), convert_options=convert_options)
+    pyarrow.Table
+    animals: string
+    n_legs: string
+    entry: timestamp[s]
+    fast: string
+    ----
+    animals: [["Flamingo","Horse","Brittle stars","Centipede",""]]
+    n_legs: [["2","4","5","100","6"]]
+    entry: [[2022-01-03 00:00:00,2022-02-03 00:00:00,2022-03-03 00:00:00,2022-04-03 00:00:00,2022-05-03 00:00:00]]
+    fast: [["Yes","Yes","No","No",""]]
     """
 
     # Avoid mistakingly creating attributes
@@ -816,7 +872,7 @@ cdef class ConvertOptions(_Weakrefable):
         self.options.reset(
             new CCSVConvertOptions(CCSVConvertOptions.Defaults()))
 
-    def __init__(self, *, check_utf8=None, column_types=None, null_values=None,
+    def __init__(self, *, check_utf8=None, column_types=None,
+                 default_column_type=None, null_values=None,
                  true_values=None, false_values=None, decimal_point=None,
                  strings_can_be_null=None, quoted_strings_can_be_null=None,
                  include_columns=None, include_missing_columns=None,
@@ -826,6 +882,8 @@ cdef class ConvertOptions(_Weakrefable):
         self.check_utf8 = check_utf8
         if column_types is not None:
             self.column_types = column_types
+        if default_column_type is not None:
+            self.default_column_type = default_column_type
         if null_values is not None:
             self.null_values = null_values
         if true_values is not None:
@@ -910,6 +968,27 @@ cdef class ConvertOptions(_Weakrefable):
             assert typ != NULL
             deref(self.options).column_types[tobytes(k)] = typ
 
+    @property
+    def default_column_type(self):
+        """
+        Default type for columns not specified in column_types.
+        """
+        if deref(self.options).default_column_type != NULL:
+            return pyarrow_wrap_data_type(deref(self.options).default_column_type)
+        else:
+            return None
+
+    @default_column_type.setter
+    def default_column_type(self, value):
+        cdef:
+            shared_ptr[CDataType] typ
+        if value is not None:
+            typ = pyarrow_unwrap_data_type(ensure_type(value))
+            assert typ != NULL
+            deref(self.options).default_column_type = typ
+        else:
+            deref(self.options).default_column_type.reset()
+
     @property
     def null_values(self):
         """
@@ -1071,6 +1150,7 @@ cdef class ConvertOptions(_Weakrefable):
         return (
             self.check_utf8 == other.check_utf8 and
             self.column_types == other.column_types and
+            self.default_column_type == other.default_column_type and
             self.null_values == other.null_values and
             self.true_values == other.true_values and
             self.false_values == other.false_values and
@@ -1087,17 +1167,17 @@ cdef class ConvertOptions(_Weakrefable):
     def __getstate__(self):
-        return (self.check_utf8, self.column_types, self.null_values,
-                self.true_values, self.false_values, self.decimal_point,
-                self.timestamp_parsers, self.strings_can_be_null,
-                self.quoted_strings_can_be_null, self.auto_dict_encode,
-                self.auto_dict_max_cardinality, self.include_columns,
-                self.include_missing_columns)
+        return (self.check_utf8, self.column_types, self.default_column_type,
+                self.null_values, self.true_values, self.false_values,
+                self.decimal_point, self.timestamp_parsers,
+                self.strings_can_be_null, self.quoted_strings_can_be_null,
+                self.auto_dict_encode, self.auto_dict_max_cardinality,
+                self.include_columns, self.include_missing_columns)
 
     def __setstate__(self, state):
-        (self.check_utf8, self.column_types, self.null_values,
-         self.true_values, self.false_values, self.decimal_point,
-         self.timestamp_parsers, self.strings_can_be_null,
+        (self.check_utf8, self.column_types, self.default_column_type,
+         self.null_values, self.true_values, self.false_values,
+         self.decimal_point, self.timestamp_parsers, self.strings_can_be_null,
          self.quoted_strings_can_be_null, self.auto_dict_encode,
         self.auto_dict_max_cardinality, self.include_columns,
         self.include_missing_columns) = state
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index f294ee4d50b..fa479391211 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -2104,6 +2104,7 @@ cdef extern from "arrow/csv/api.h" namespace "arrow::csv" nogil:
     cdef cppclass CCSVConvertOptions" arrow::csv::ConvertOptions":
         c_bool check_utf8
         unordered_map[c_string, shared_ptr[CDataType]] column_types
+        shared_ptr[CDataType] default_column_type
         vector[c_string] null_values
         vector[c_string] true_values
         vector[c_string] false_values
diff --git a/python/pyarrow/tests/test_csv.py b/python/pyarrow/tests/test_csv.py
index f510c6dbe23..a4840bcb9f2 100644
--- a/python/pyarrow/tests/test_csv.py
+++ b/python/pyarrow/tests/test_csv.py
@@ -297,7 +297,8 @@ def test_convert_options(pickle_module):
         include_columns=['def', 'abc'],
         include_missing_columns=False,
         auto_dict_encode=True,
-        timestamp_parsers=[ISO8601, '%y-%m'])
+        timestamp_parsers=[ISO8601, '%y-%m'],
+        default_column_type=pa.int16())
 
     with pytest.raises(ValueError):
         opts.decimal_point = '..'
@@ -325,6 +326,17 @@ def test_convert_options(pickle_module):
     with pytest.raises(TypeError):
         opts.column_types = 0
 
+    assert opts.default_column_type is None
+    opts.default_column_type = pa.string()
+    assert opts.default_column_type == pa.string()
+    opts.default_column_type = 'int32'
+    assert opts.default_column_type == pa.int32()
+    opts.default_column_type = None
+    assert opts.default_column_type is None
+
+    with pytest.raises(TypeError, match='DataType expected'):
+        opts.default_column_type = 123
+
     assert isinstance(opts.null_values, list)
     assert '' in opts.null_values
     assert 'N/A' in opts.null_values
@@ -1331,6 +1343,57 @@
             'y': ['b', 'd', 'f'],
         }
 
+    def test_default_column_type(self):
+        rows = b"a,b,c,d\n001,2.5,hello,true\n4,3.14,world,false\n"
+
+        # With default_column_type only, every column should use the
+        # specified type
+        opts = ConvertOptions(default_column_type=pa.string())
+        table = self.read_bytes(rows, convert_options=opts)
+        schema = pa.schema([('a', pa.string()),
+                            ('b', pa.string()),
+                            ('c', pa.string()),
+                            ('d', pa.string())])
+        assert table.schema == schema
+        assert table.to_pydict() == {
+            'a': ["001", "4"],
+            'b': ["2.5", "3.14"],
+            'c': ["hello", "world"],
+            'd': ["true", "false"],
+        }
+
+        # With both column_types and default_column_type, columns
+        # specified in column_types override default_column_type
+        opts = ConvertOptions(
+            column_types={'b': pa.float64(), 'd': pa.bool_()},
+            default_column_type=pa.string()
+        )
+        table = self.read_bytes(rows, convert_options=opts)
+        schema = pa.schema([('a', pa.string()),
+                            ('b', pa.float64()),
+                            ('c', pa.string()),
+                            ('d', pa.bool_())])
+        assert table.schema == schema
+        assert table.to_pydict() == {
+            'a': ["001", "4"],
+            'b': [2.5, 3.14],
+            'c': ["hello", "world"],
+            'd': [True, False],
+        }
+
+        # Test that default_column_type disables type inference
+        opts_no_default = ConvertOptions(column_types={'b': pa.float64()})
+        table_no_default = self.read_bytes(
+            rows, convert_options=opts_no_default)
+
+        opts_with_default = ConvertOptions(
+            column_types={'b': pa.float64()},
+            default_column_type=pa.string()
+        )
+        table_with_default = self.read_bytes(
+            rows, convert_options=opts_with_default)
+
+        # Column 'a' should be int64 without the default, string with it
+        assert table_no_default.schema.field('a').type == pa.int64()
+        assert table_with_default.schema.field('a').type == pa.string()
+
     def test_no_ending_newline(self):
         # No \n after last line
         rows = b"a,b,c\n1,2,3\n4,5,6"
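
For reviewers who want to try the new option end to end, here is a minimal usage sketch from the Python side, mirroring the doctests added above. It assumes a pyarrow build with this patch applied; the sample data and variable names are illustrative only:

    import io

    import pyarrow as pa
    from pyarrow import csv

    data = b"id,name,value\n0001,apple,003.14\n0002,banana,001.62\n"

    # Columns not listed in column_types fall back to default_column_type,
    # so `name` and `value` are read as string (leading zeros preserved),
    # while `id` alone is parsed as int64.
    convert_options = csv.ConvertOptions(
        column_types={"id": pa.int64()},
        default_column_type=pa.string(),
    )
    table = csv.read_csv(io.BytesIO(data), convert_options=convert_options)

    print(table.schema)                       # id: int64, name: string, value: string
    print(table.column("value").to_pylist())  # ['003.14', '001.62']

Without default_column_type, `value` would be inferred as double and the leading zeros would be lost; with it, type inference is skipped entirely for the unlisted columns.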