diff --git a/csp/adapters/arrow.py b/csp/adapters/arrow.py
index e32a31c53..e01dd7106 100644
--- a/csp/adapters/arrow.py
+++ b/csp/adapters/arrow.py
@@ -2,6 +2,7 @@
 import pyarrow as pa
 import pyarrow.parquet as pq
+from packaging.version import parse

 import csp
 from csp.impl.types.tstype import ts
@@ -14,6 +15,8 @@
     "write_record_batches",
 ]

+_PYARROW_HAS_CONCAT_BATCHES = parse(pa.__version__) >= parse("19.0.0")
+
 CRecordBatchPullInputAdapter = input_adapter_def(
     "CRecordBatchPullInputAdapter",
@@ -73,6 +76,18 @@ def RecordBatchPullInputAdapter(
     )


+def _concat_batches(batches: list[pa.RecordBatch]) -> pa.RecordBatch:
+    if _PYARROW_HAS_CONCAT_BATCHES:
+        # pyarrow 19+ supports the concat_batches API
+        return pa.concat_batches(batches)
+    else:
+        combined_table = pa.Table.from_batches(batches).combine_chunks()
+        combined_batches = combined_table.to_batches()
+        if len(combined_batches) > 1:
+            raise ValueError("Not able to combine multiple record batches into one record batch")
+        return combined_batches[0]
+
+
 @csp.node
 def write_record_batches(
     where: str,
@@ -102,12 +117,12 @@ def write_record_batches(
     with csp.stop():
         if s_writer:
             if s_prev_batch:
-                s_writer.write_batch(pa.concat_batches(s_prev_batch))
+                s_writer.write_batch(_concat_batches(s_prev_batch))
             s_writer.close()

     if csp.ticked(batches):
         if s_merge_batches:
-            batches = [pa.concat_batches(batches)]
+            batches = [_concat_batches(batches)]

         for batch in batches:
             if len(batch) == 0:
@@ -118,7 +133,7 @@ def write_record_batches(
                 s_prev_batch = [batch]
                 s_prev_batch_size = len(batch)
             elif s_prev_batch_size + len(batch) > s_max_batch_size:
-                s_writer.write_batch(pa.concat_batches(s_prev_batch))
+                s_writer.write_batch(_concat_batches(s_prev_batch))
                 s_prev_batch = [batch]
                 s_prev_batch_size = len(batch)
             else:
diff --git a/csp/tests/adapters/test_arrow.py b/csp/tests/adapters/test_arrow.py
index c2fdd8b37..8f59846ea 100644
--- a/csp/tests/adapters/test_arrow.py
+++ b/csp/tests/adapters/test_arrow.py
@@ -22,6 +22,14 @@ def WB(where: str, merge: bool, batch_size: int, batches: csp.ts[[pa.RecordBatch
     data = write_record_batches(where, batches, {}, merge, batch_size)


+def _concat_batches(batches: list[pa.RecordBatch]) -> pa.RecordBatch:
+    combined_table = pa.Table.from_batches(batches).combine_chunks()
+    combined_batches = combined_table.to_batches()
+    if len(combined_batches) > 1:
+        raise ValueError("Not able to combine multiple record batches into one record batch")
+    return combined_batches[0]
+
+
 class TestArrow:
     def make_record_batch(self, ts_col_name: str, row_size: int, ts: datetime) -> pa.RecordBatch:
         data = {
@@ -29,8 +37,7 @@ def make_record_batch(self, ts_col_name: str, row_size: int, ts: datetime) -> pa
             "name": pa.array([chr(ord("A") + idx % 26) for idx in range(row_size)]),
         }
         schema = pa.schema([(ts_col_name, pa.timestamp("ms")), ("name", pa.string())])
-        rb = pa.RecordBatch.from_pydict(data)
-        return rb.cast(schema)
+        return pa.RecordBatch.from_pydict(data, schema=schema)

     def make_data(self, ts_col_name: str, row_sizes: [int], start: datetime = _STARTTIME, interval: int = 1):
         res = [
@@ -100,7 +107,7 @@ def test_start_found(self, small_batches: bool, row_sizes: [int], row_sizes_prev
         assert [len(r[1][0]) for r in results["data"]] == clean_row_sizes
         assert [r[1][0] for r in results["data"]] == clean_rbs

-        results = csp.run(G, "TsCol", schema, [pa.concat_batches(full_rbs)], small_batches, starttime=dt_start - delta)
+        results = csp.run(G, "TsCol", schema, [_concat_batches(full_rbs)], small_batches, starttime=dt_start - delta)
         assert len(results["data"]) == len(clean_row_sizes)
         assert [len(r[1][0]) for r in results["data"]] == clean_row_sizes
         assert [r[1][0] for r in results["data"]] == clean_rbs
@@ -126,7 +133,7 @@ def test_split(self, small_batches: bool, row_sizes: [int], repeat: int, dt_coun
         for idx, tup in enumerate(results["data"]):
             assert tup[1] == rbs_indivs[idx]

-        results = csp.run(G, "TsCol", schema, [pa.concat_batches(rbs_full)], small_batches, starttime=_STARTTIME)
+        results = csp.run(G, "TsCol", schema, [_concat_batches(rbs_full)], small_batches, starttime=_STARTTIME)
         assert len(results["data"]) == len(rbs_indivs)
         for idx, tup in enumerate(results["data"]):
             assert pa.Table.from_batches(tup[1]) == pa.Table.from_batches(rbs_indivs[idx])
@@ -201,7 +208,7 @@ def test_write_record_batches_concat(self, row_sizes: [int], concat: bool):
             if not concat:
                 rbs_ts_expected = [rb[0] for rb in rbs_ts]
             else:
-                rbs_ts_expected = [pa.concat_batches(rbs_ts[0])]
+                rbs_ts_expected = [_concat_batches(rbs_ts[0])]
             assert rbs_ts_expected == res.to_batches()

     def test_write_record_batches_batch_sizes(self):
@@ -214,7 +221,7 @@ def test_write_record_batches_batch_sizes(self):
             res = pq.read_table(temp_file.name)
             orig = pa.Table.from_batches(rbs)
             assert res.equals(orig)
-            rbs_ts_expected = [pa.concat_batches(rbs[2 * i : 2 * i + 2]) for i in range(5)]
+            rbs_ts_expected = [_concat_batches(rbs[2 * i : 2 * i + 2]) for i in range(5)]
             assert rbs_ts_expected == res.to_batches()

         row_sizes = [10] * 10
@@ -226,5 +233,5 @@ def test_write_record_batches_batch_sizes(self):
             res = pq.read_table(temp_file.name)
             orig = pa.Table.from_batches(rbs)
             assert res.equals(orig)
-            rbs_ts_expected = [pa.concat_batches(rbs[3 * i : 3 * i + 3]) for i in range(4)]
+            rbs_ts_expected = [_concat_batches(rbs[3 * i : 3 * i + 3]) for i in range(4)]
             assert rbs_ts_expected == res.to_batches()
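Note (not part of the patch): a minimal, self-contained sketch of what the pyarrow < 19 fallback path in _concat_batches does, using only public pyarrow APIs (Table.from_batches, combine_chunks, to_batches); on pyarrow 19+ the helper above simply delegates to pa.concat_batches.

    import pyarrow as pa

    def _concat_batches(batches):
        # Fallback mirrored from the patch: round-trip the batches through a
        # Table and collapse its chunks so the result is a single RecordBatch.
        combined = pa.Table.from_batches(batches).combine_chunks().to_batches()
        if len(combined) > 1:
            raise ValueError("Not able to combine multiple record batches into one record batch")
        return combined[0]

    # Example: two 2-row batches combine into one 4-row batch.
    a = pa.RecordBatch.from_pydict({"x": [1, 2]})
    b = pa.RecordBatch.from_pydict({"x": [3, 4]})
    assert _concat_batches([a, b]).num_rows == 4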