Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/src/arrow/csv/options.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ ConvertOptions ConvertOptions::Defaults() {
"NULL", "NaN", "n/a", "nan", "null"};
options.true_values = {"1", "True", "TRUE", "true"};
options.false_values = {"0", "False", "FALSE", "false"};
options.default_column_type = nullptr;
return options;
}

Expand Down
2 changes: 2 additions & 0 deletions cpp/src/arrow/csv/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ struct ARROW_EXPORT ConvertOptions {
bool check_utf8 = true;
/// Optional per-column types (disabling type inference on those columns)
std::unordered_map<std::string, std::shared_ptr<DataType>> column_types;
/// Default type to use for columns not in `column_types`
std::shared_ptr<DataType> default_column_type;
/// Recognized spellings for null values
std::vector<std::string> null_values;
/// Recognized spellings for boolean true values
Expand Down
11 changes: 9 additions & 2 deletions cpp/src/arrow/csv/reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -674,8 +674,15 @@ class ReaderMixin {
// Does the named column have a fixed type?
auto it = convert_options_.column_types.find(col_name);
if (it == convert_options_.column_types.end()) {
conversion_schema_.columns.push_back(
ConversionSchema::InferredColumn(std::move(col_name), col_index));
// If not explicitly typed, respect default_column_type when provided
if (convert_options_.default_column_type != nullptr) {
conversion_schema_.columns.push_back(ConversionSchema::TypedColumn(
std::move(col_name), col_index, convert_options_.default_column_type));
}
else {
conversion_schema_.columns.push_back(
ConversionSchema::InferredColumn(std::move(col_name), col_index));
}
} else {
conversion_schema_.columns.push_back(
ConversionSchema::TypedColumn(std::move(col_name), col_index, it->second));
Expand Down
133 changes: 133 additions & 0 deletions cpp/src/arrow/csv/reader_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -488,5 +488,138 @@ TEST(CountRowsAsync, Errors) {
internal::GetCpuThreadPool(), read_options, parse_options));
}

TEST(ReaderTests, DefaultColumnTypePartialDefault) {
  // Columns absent from `column_types` fall back to the default type (string);
  // `id` is explicitly mapped to int64 and therefore parsed numerically
  // (leading zeros stripped).
  auto csv_data = std::make_shared<Buffer>(
      "id,name,value,date\n"
      "0000101,apple,0003.1400,2024-01-15\n"
      "00102,banana,001.6180,2024-02-20\n"
      "0003,cherry,02.71800,2024-03-25\n");

  auto convert_options = ConvertOptions::Defaults();
  convert_options.default_column_type = utf8();
  convert_options.column_types["id"] = int64();

  auto stream = std::make_shared<io::BufferReader>(csv_data);
  ASSERT_OK_AND_ASSIGN(auto reader,
                       TableReader::Make(io::default_io_context(), stream,
                                         ReadOptions::Defaults(),
                                         ParseOptions::Defaults(), convert_options));
  ASSERT_OK_AND_ASSIGN(auto table, reader->Read());

  auto expected_schema = schema({field("id", int64()), field("name", utf8()),
                                 field("value", utf8()), field("date", utf8())});
  AssertSchemaEqual(expected_schema, table->schema());

  auto expected_table = TableFromJSON(
      expected_schema,
      {R"([{"id":101, "name":"apple", "value":"0003.1400", "date":"2024-01-15"},
{"id":102, "name":"banana", "value":"001.6180", "date":"2024-02-20"},
{"id":3, "name":"cherry", "value":"02.71800", "date":"2024-03-25"}])"});
  ASSERT_TRUE(table->Equals(*expected_table));
}

TEST(ReaderTests, DefaultColumnTypeAllStringsWithHeader) {
  // With only `default_column_type` set, every column (named from the header
  // row) is read as a string and no type inference takes place, so leading
  // zeros are preserved verbatim.
  auto csv_data = std::make_shared<Buffer>(
      "Record_Type|ID|Code|Quantity_1|Quantity_2|Amount_1|Amount_2|Amount_3|Flag|Note|Total_Amount\n"
      "AB|000388907|abc|0150|012|000045.67|000000.10|000001.25|Y|noteA|000045.6700\n");

  auto parse_options = ParseOptions::Defaults();
  parse_options.delimiter = '|';
  auto convert_options = ConvertOptions::Defaults();
  convert_options.default_column_type = utf8();

  auto stream = std::make_shared<io::BufferReader>(csv_data);
  ASSERT_OK_AND_ASSIGN(auto reader,
                       TableReader::Make(io::default_io_context(), stream,
                                         ReadOptions::Defaults(), parse_options,
                                         convert_options));
  ASSERT_OK_AND_ASSIGN(auto table, reader->Read());

  auto expected_schema = schema(
      {field("Record_Type", utf8()), field("ID", utf8()), field("Code", utf8()),
       field("Quantity_1", utf8()), field("Quantity_2", utf8()),
       field("Amount_1", utf8()), field("Amount_2", utf8()),
       field("Amount_3", utf8()), field("Flag", utf8()), field("Note", utf8()),
       field("Total_Amount", utf8())});
  AssertSchemaEqual(expected_schema, table->schema());

  auto expected_table = TableFromJSON(
      expected_schema,
      {R"([{
"Record_Type":"AB",
"ID":"000388907",
"Code":"abc",
"Quantity_1":"0150",
"Quantity_2":"012",
"Amount_1":"000045.67",
"Amount_2":"000000.10",
"Amount_3":"000001.25",
"Flag":"Y",
"Note":"noteA",
"Total_Amount":"000045.6700"
}])"});
  ASSERT_TRUE(table->Equals(*expected_table));
}

TEST(ReaderTests, DefaultColumnTypeAllStringsNoHeader) {
  // No header row: column names are autogenerated (f0, f1, ...) and every
  // column is read as a string via `default_column_type`.
  auto csv_data = std::make_shared<Buffer>(
      "AB|000388907|abc|0150|012|000045.67|000000.10|000001.25|Y|noteA|000045.6700\n");

  auto read_options = ReadOptions::Defaults();
  read_options.autogenerate_column_names = true;  // treat first row as data
  auto parse_options = ParseOptions::Defaults();
  parse_options.delimiter = '|';
  auto convert_options = ConvertOptions::Defaults();
  convert_options.default_column_type = utf8();

  auto stream = std::make_shared<io::BufferReader>(csv_data);
  ASSERT_OK_AND_ASSIGN(auto reader,
                       TableReader::Make(io::default_io_context(), stream,
                                         read_options, parse_options,
                                         convert_options));
  ASSERT_OK_AND_ASSIGN(auto table, reader->Read());

  auto expected_schema = schema(
      {field("f0", utf8()), field("f1", utf8()), field("f2", utf8()),
       field("f3", utf8()), field("f4", utf8()), field("f5", utf8()),
       field("f6", utf8()), field("f7", utf8()), field("f8", utf8()),
       field("f9", utf8()), field("f10", utf8())});
  AssertSchemaEqual(expected_schema, table->schema());

  auto expected_table = TableFromJSON(
      expected_schema,
      {R"([{
"f0":"AB",
"f1":"000388907",
"f2":"abc",
"f3":"0150",
"f4":"012",
"f5":"000045.67",
"f6":"000000.10",
"f7":"000001.25",
"f8":"Y",
"f9":"noteA",
"f10":"000045.6700"
}])"});
  ASSERT_TRUE(table->Equals(*expected_table));
}

} // namespace csv
} // namespace arrow
1 change: 1 addition & 0 deletions docs/source/python/csv.rst
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ Available convert options are:

~ConvertOptions.check_utf8
~ConvertOptions.column_types
~ConvertOptions.default_column_type
~ConvertOptions.null_values
~ConvertOptions.true_values
~ConvertOptions.false_values
Expand Down
100 changes: 90 additions & 10 deletions python/pyarrow/_csv.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,9 @@ cdef class ConvertOptions(_Weakrefable):
column_types : pyarrow.Schema or dict, optional
Explicitly map column names to column types. Passing this argument
disables type inference on the defined columns.
default_column_type : pyarrow.DataType, optional
    Explicitly map columns not specified in column_types to a default type.
    Passing this argument disables type inference on all columns not
    already listed in column_types.
null_values : list, optional
A sequence of strings that denote nulls in the data
(defaults are appropriate in most cases). Note that by default,
Expand Down Expand Up @@ -807,6 +810,59 @@ cdef class ConvertOptions(_Weakrefable):
fast: bool
----
fast: [[true,true,false,false,null]]

Set a default column type for all columns (disables type inference):

>>> convert_options = csv.ConvertOptions(default_column_type=pa.string())
>>> csv.read_csv(io.BytesIO(s.encode()), convert_options=convert_options)
pyarrow.Table
animals: string
n_legs: string
entry: string
fast: string
----
animals: [["Flamingo","Horse","Brittle stars","Centipede",""]]
n_legs: [["2","4","5","100","6"]]
entry: [["01/03/2022","02/03/2022","03/03/2022","04/03/2022","05/03/2022"]]
fast: [["Yes","Yes","No","No",""]]

Combine default_column_type with column_types (specific column types override default):

>>> convert_options = csv.ConvertOptions(
... column_types={"n_legs": pa.int64(), "fast": pa.bool_()},
... default_column_type=pa.string(),
... true_values=["Yes"],
... false_values=["No"])
>>> csv.read_csv(io.BytesIO(s.encode()), convert_options=convert_options)
pyarrow.Table
animals: string
n_legs: int64
entry: string
fast: bool
----
animals: [["Flamingo","Horse","Brittle stars","Centipede",""]]
n_legs: [[2,4,5,100,6]]
entry: [["01/03/2022","02/03/2022","03/03/2022","04/03/2022","05/03/2022"]]
fast: [[true,true,false,false,null]]

Use default_column_type with selective column_types for mixed type conversion:

>>> convert_options = csv.ConvertOptions(
... column_types={"animals": pa.string(),
... "entry": pa.timestamp('s')},
... default_column_type=pa.string(),
... timestamp_parsers=["%m/%d/%Y"])
>>> csv.read_csv(io.BytesIO(s.encode()), convert_options=convert_options)
pyarrow.Table
animals: string
n_legs: string
entry: timestamp[s]
fast: string
----
animals: [["Flamingo","Horse","Brittle stars","Centipede",""]]
n_legs: [["2","4","5","100","6"]]
entry: [[2022-01-03 00:00:00,2022-02-03 00:00:00,2022-03-03 00:00:00,2022-04-03 00:00:00,2022-05-03 00:00:00]]
fast: [["Yes","Yes","No","No",""]]
"""

# Avoid mistakingly creating attributes
Expand All @@ -816,7 +872,7 @@ cdef class ConvertOptions(_Weakrefable):
self.options.reset(
new CCSVConvertOptions(CCSVConvertOptions.Defaults()))

def __init__(self, *, check_utf8=None, column_types=None, null_values=None,
def __init__(self, *, check_utf8=None, column_types=None, default_column_type=None, null_values=None,
true_values=None, false_values=None, decimal_point=None,
strings_can_be_null=None, quoted_strings_can_be_null=None,
include_columns=None, include_missing_columns=None,
Expand All @@ -826,6 +882,8 @@ cdef class ConvertOptions(_Weakrefable):
self.check_utf8 = check_utf8
if column_types is not None:
self.column_types = column_types
if default_column_type is not None:
self.default_column_type = default_column_type
if null_values is not None:
self.null_values = null_values
if true_values is not None:
Expand Down Expand Up @@ -910,6 +968,27 @@ cdef class ConvertOptions(_Weakrefable):
assert typ != NULL
deref(self.options).column_types[tobytes(k)] = typ

@property
def default_column_type(self):
    """
    Explicitly map columns not specified in column_types to a default type.
    """
    # A NULL shared_ptr means the option is unset: return None so Python
    # callers see "no default type" (per-column inference applies).
    if deref(self.options).default_column_type != NULL:
        return pyarrow_wrap_data_type(deref(self.options).default_column_type)
    else:
        return None

@default_column_type.setter
def default_column_type(self, value):
    # Accepts a pyarrow DataType (or anything ensure_type() can convert),
    # or None to clear the option and re-enable type inference.
    cdef:
        shared_ptr[CDataType] typ
    if value is not None:
        typ = pyarrow_unwrap_data_type(ensure_type(value))
        assert typ != NULL
        deref(self.options).default_column_type = typ
    else:
        # reset() releases the stored type, restoring the unset (NULL) state.
        deref(self.options).default_column_type.reset()

@property
def null_values(self):
"""
Expand Down Expand Up @@ -1071,6 +1150,7 @@ cdef class ConvertOptions(_Weakrefable):
return (
self.check_utf8 == other.check_utf8 and
self.column_types == other.column_types and
self.default_column_type == other.default_column_type and
self.null_values == other.null_values and
self.true_values == other.true_values and
self.false_values == other.false_values and
Expand All @@ -1087,17 +1167,17 @@ cdef class ConvertOptions(_Weakrefable):
)

def __getstate__(self):
return (self.check_utf8, self.column_types, self.null_values,
self.true_values, self.false_values, self.decimal_point,
self.timestamp_parsers, self.strings_can_be_null,
self.quoted_strings_can_be_null, self.auto_dict_encode,
self.auto_dict_max_cardinality, self.include_columns,
self.include_missing_columns)
return (self.check_utf8, self.column_types, self.default_column_type,
self.null_values, self.true_values, self.false_values,
self.decimal_point, self.timestamp_parsers,
self.strings_can_be_null, self.quoted_strings_can_be_null,
self.auto_dict_encode, self.auto_dict_max_cardinality,
self.include_columns, self.include_missing_columns)

def __setstate__(self, state):
(self.check_utf8, self.column_types, self.null_values,
self.true_values, self.false_values, self.decimal_point,
self.timestamp_parsers, self.strings_can_be_null,
(self.check_utf8, self.column_types, self.default_column_type,
self.null_values, self.true_values, self.false_values,
self.decimal_point, self.timestamp_parsers, self.strings_can_be_null,
self.quoted_strings_can_be_null, self.auto_dict_encode,
self.auto_dict_max_cardinality, self.include_columns,
self.include_missing_columns) = state
Expand Down
1 change: 1 addition & 0 deletions python/pyarrow/includes/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -2104,6 +2104,7 @@ cdef extern from "arrow/csv/api.h" namespace "arrow::csv" nogil:
cdef cppclass CCSVConvertOptions" arrow::csv::ConvertOptions":
c_bool check_utf8
unordered_map[c_string, shared_ptr[CDataType]] column_types
shared_ptr[CDataType] default_column_type
vector[c_string] null_values
vector[c_string] true_values
vector[c_string] false_values
Expand Down
Loading