Commit f4460d3

Bugfix: Support arrow adapter for arrow versions <19 (#561)
Signed-off-by: Arham Chopra <[email protected]>
1 parent b65df16 commit f4460d3
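In short, the fix gates the use of pa.concat_batches, which is only available on pyarrow 19+, behind a version check and falls back to a pa.Table round-trip on older releases. A minimal sketch of the version gate, reusing the names introduced in the diff below:

import pyarrow as pa
from packaging.version import parse

# True only on pyarrow >= 19.0.0, where pa.concat_batches exists.
_PYARROW_HAS_CONCAT_BATCHES = parse(pa.__version__) >= parse("19.0.0")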

File tree

2 files changed: +32 -10 lines


csp/adapters/arrow.py

Lines changed: 18 additions & 3 deletions
@@ -2,6 +2,7 @@
 
 import pyarrow as pa
 import pyarrow.parquet as pq
+from packaging.version import parse
 
 import csp
 from csp.impl.types.tstype import ts
@@ -14,6 +15,8 @@
     "write_record_batches",
 ]
 
+_PYARROW_HAS_CONCAT_BATCHES = parse(pa.__version__) >= parse("19.0.0")
+
 
 CRecordBatchPullInputAdapter = input_adapter_def(
     "CRecordBatchPullInputAdapter",
@@ -73,6 +76,18 @@ def RecordBatchPullInputAdapter(
     )
 
 
+def _concat_batches(batches: list[pa.RecordBatch]) -> pa.RecordBatch:
+    if _PYARROW_HAS_CONCAT_BATCHES:
+        # pyarrow version 19+ supports the concat_batches API
+        return pa.concat_batches(batches)
+    else:
+        combined_table = pa.Table.from_batches(batches).combine_chunks()
+        combined_batches = combined_table.to_batches()
+        if len(combined_batches) > 1:
+            raise ValueError("Not able to combine multiple record batches into one record batch")
+        return combined_batches[0]
+
+
 @csp.node
 def write_record_batches(
     where: str,
@@ -102,12 +117,12 @@ def write_record_batches(
     with csp.stop():
         if s_writer:
             if s_prev_batch:
-                s_writer.write_batch(pa.concat_batches(s_prev_batch))
+                s_writer.write_batch(_concat_batches(s_prev_batch))
             s_writer.close()
 
     if csp.ticked(batches):
         if s_merge_batches:
-            batches = [pa.concat_batches(batches)]
+            batches = [_concat_batches(batches)]
 
         for batch in batches:
             if len(batch) == 0:
@@ -118,7 +133,7 @@ def write_record_batches(
                 s_prev_batch = [batch]
                 s_prev_batch_size = len(batch)
             elif s_prev_batch_size + len(batch) > s_max_batch_size:
-                s_writer.write_batch(pa.concat_batches(s_prev_batch))
+                s_writer.write_batch(_concat_batches(s_prev_batch))
                 s_prev_batch = [batch]
                 s_prev_batch_size = len(batch)
             else:
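For context, a standalone sketch of the pre-19 fallback path used by _concat_batches above: concatenate through a pa.Table, collapse the chunks, and take the single resulting batch. The sample values here are illustrative only, not taken from the adapter:

import pyarrow as pa

# Two small record batches with the same schema (illustrative data).
a = pa.record_batch({"x": [1, 2]})
b = pa.record_batch({"x": [3, 4]})

# Fallback used when pa.concat_batches is unavailable (pyarrow < 19).
combined = pa.Table.from_batches([a, b]).combine_chunks().to_batches()
assert len(combined) == 1
assert combined[0].num_rows == 4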

csp/tests/adapters/test_arrow.py

Lines changed: 14 additions & 7 deletions
@@ -22,15 +22,22 @@ def WB(where: str, merge: bool, batch_size: int, batches: csp.ts[[pa.RecordBatch
     data = write_record_batches(where, batches, {}, merge, batch_size)
 
 
+def _concat_batches(batches: list[pa.RecordBatch]) -> pa.RecordBatch:
+    combined_table = pa.Table.from_batches(batches).combine_chunks()
+    combined_batches = combined_table.to_batches()
+    if len(combined_batches) > 1:
+        raise ValueError("Not able to combine multiple record batches into one record batch")
+    return combined_batches[0]
+
+
 class TestArrow:
     def make_record_batch(self, ts_col_name: str, row_size: int, ts: datetime) -> pa.RecordBatch:
         data = {
             ts_col_name: pa.array([ts] * row_size, type=pa.timestamp("ms")),
             "name": pa.array([chr(ord("A") + idx % 26) for idx in range(row_size)]),
         }
         schema = pa.schema([(ts_col_name, pa.timestamp("ms")), ("name", pa.string())])
-        rb = pa.RecordBatch.from_pydict(data)
-        return rb.cast(schema)
+        return pa.RecordBatch.from_pydict(data, schema=schema)
 
     def make_data(self, ts_col_name: str, row_sizes: [int], start: datetime = _STARTTIME, interval: int = 1):
         res = [
@@ -100,7 +107,7 @@ def test_start_found(self, small_batches: bool, row_sizes: [int], row_sizes_prev
         assert [len(r[1][0]) for r in results["data"]] == clean_row_sizes
         assert [r[1][0] for r in results["data"]] == clean_rbs
 
-        results = csp.run(G, "TsCol", schema, [pa.concat_batches(full_rbs)], small_batches, starttime=dt_start - delta)
+        results = csp.run(G, "TsCol", schema, [_concat_batches(full_rbs)], small_batches, starttime=dt_start - delta)
         assert len(results["data"]) == len(clean_row_sizes)
         assert [len(r[1][0]) for r in results["data"]] == clean_row_sizes
         assert [r[1][0] for r in results["data"]] == clean_rbs
@@ -126,7 +133,7 @@ def test_split(self, small_batches: bool, row_sizes: [int], repeat: int, dt_coun
         for idx, tup in enumerate(results["data"]):
             assert tup[1] == rbs_indivs[idx]
 
-        results = csp.run(G, "TsCol", schema, [pa.concat_batches(rbs_full)], small_batches, starttime=_STARTTIME)
+        results = csp.run(G, "TsCol", schema, [_concat_batches(rbs_full)], small_batches, starttime=_STARTTIME)
         assert len(results["data"]) == len(rbs_indivs)
         for idx, tup in enumerate(results["data"]):
             assert pa.Table.from_batches(tup[1]) == pa.Table.from_batches(rbs_indivs[idx])
@@ -201,7 +208,7 @@ def test_write_record_batches_concat(self, row_sizes: [int], concat: bool):
         if not concat:
             rbs_ts_expected = [rb[0] for rb in rbs_ts]
         else:
-            rbs_ts_expected = [pa.concat_batches(rbs_ts[0])]
+            rbs_ts_expected = [_concat_batches(rbs_ts[0])]
         assert rbs_ts_expected == res.to_batches()
 
     def test_write_record_batches_batch_sizes(self):
@@ -214,7 +221,7 @@ def test_write_record_batches_batch_sizes(self):
         res = pq.read_table(temp_file.name)
         orig = pa.Table.from_batches(rbs)
         assert res.equals(orig)
-        rbs_ts_expected = [pa.concat_batches(rbs[2 * i : 2 * i + 2]) for i in range(5)]
+        rbs_ts_expected = [_concat_batches(rbs[2 * i : 2 * i + 2]) for i in range(5)]
         assert rbs_ts_expected == res.to_batches()
 
         row_sizes = [10] * 10
@@ -226,5 +233,5 @@ def test_write_record_batches_batch_sizes(self):
         res = pq.read_table(temp_file.name)
         orig = pa.Table.from_batches(rbs)
         assert res.equals(orig)
-        rbs_ts_expected = [pa.concat_batches(rbs[3 * i : 3 * i + 3]) for i in range(4)]
+        rbs_ts_expected = [_concat_batches(rbs[3 * i : 3 * i + 3]) for i in range(4)]
         assert rbs_ts_expected == res.to_batches()
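As a sanity check outside the commit, on pyarrow 19+ the table-based helper should produce the same batch as the native API; a hypothetical snippet along those lines:

import pyarrow as pa

batches = [pa.record_batch({"x": [i, i + 1]}) for i in range(0, 6, 2)]
fallback = pa.Table.from_batches(batches).combine_chunks().to_batches()[0]

# pa.concat_batches only exists on pyarrow >= 19.
if hasattr(pa, "concat_batches"):
    assert fallback.equals(pa.concat_batches(batches))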
