diff --git a/python/pyarrow/_csv.pyx b/python/pyarrow/_csv.pyx index 62cb75fa6ea..ed9d20beb6b 100644 --- a/python/pyarrow/_csv.pyx +++ b/python/pyarrow/_csv.pyx @@ -1364,13 +1364,16 @@ cdef class WriteOptions(_Weakrefable): - "none": do not enclose any values in quotes; values containing special characters (such as quotes, cell delimiters or line endings) will raise an error. + quoting_header : str, optional (default "needed") + Same as quoting_style, but for header column names. Accepts same values. + Note : both "needed" and "all_valid" have the same effect of quoting all column names. """ # Avoid mistakingly creating attributes __slots__ = () def __init__(self, *, include_header=None, batch_size=None, - delimiter=None, quoting_style=None): + delimiter=None, quoting_style=None, quoting_header=None): self.options.reset(new CCSVWriteOptions(CCSVWriteOptions.Defaults())) if include_header is not None: self.include_header = include_header @@ -1380,6 +1383,8 @@ cdef class WriteOptions(_Weakrefable): self.delimiter = delimiter if quoting_style is not None: self.quoting_style = quoting_style + if quoting_header is not None: + self.quoting_header = quoting_header @property def include_header(self): @@ -1433,6 +1438,18 @@ cdef class WriteOptions(_Weakrefable): def quoting_style(self, value): deref(self.options).quoting_style = unwrap_quoting_style(value) + @property + def quoting_header(self): + """ + Same as quoting_style, but for header column names. + Note : both "needed" and "all_valid" have the same effect of quoting all column names. + """ + return wrap_quoting_style(deref(self.options).quoting_header) + + @quoting_header.setter + def quoting_header(self, value): + deref(self.options).quoting_header = unwrap_quoting_style(value) + @staticmethod cdef WriteOptions wrap(CCSVWriteOptions options): out = WriteOptions() diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 39dc3a77d98..f294ee4d50b 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2147,6 +2147,7 @@ cdef extern from "arrow/csv/api.h" namespace "arrow::csv" nogil: int32_t batch_size unsigned char delimiter CQuotingStyle quoting_style + CQuotingStyle quoting_header CIOContext io_context CCSVWriteOptions() diff --git a/python/pyarrow/tests/test_csv.py b/python/pyarrow/tests/test_csv.py index 2794d07e87c..f510c6dbe23 100644 --- a/python/pyarrow/tests/test_csv.py +++ b/python/pyarrow/tests/test_csv.py @@ -2003,6 +2003,21 @@ def test_write_quoting_style(): buf.seek(0) +def test_write_quoting_header(): + t = pa.Table.from_arrays([[1, 2, None], ["a", None, "c"]], ["c1", "c2"]) + buf = io.BytesIO() + for write_options, res in [ + (WriteOptions(quoting_header='none'), b'c1,c2\n1,"a"\n2,\n,"c"\n'), + (WriteOptions(), b'"c1","c2"\n1,"a"\n2,\n,"c"\n'), + (WriteOptions(quoting_header='all_valid'), + b'"c1","c2"\n1,"a"\n2,\n,"c"\n'), + ]: + with CSVWriter(buf, t.schema, write_options=write_options) as writer: + writer.write_table(t) + assert buf.getvalue() == res + buf.seek(0) + + def test_read_csv_reference_cycle(): # ARROW-13187 def inner():