21 changes: 18 additions & 3 deletions csp/adapters/arrow.py
@@ -2,6 +2,7 @@

import pyarrow as pa
import pyarrow.parquet as pq
from packaging.version import parse

import csp
from csp.impl.types.tstype import ts
@@ -14,6 +15,8 @@
"write_record_batches",
]

_PYARROW_HAS_CONCAT_BATCHES = parse(pa.__version__) >= parse("19.0.0")


CRecordBatchPullInputAdapter = input_adapter_def(
"CRecordBatchPullInputAdapter",
@@ -73,6 +76,18 @@ def RecordBatchPullInputAdapter(
)


def _concat_batches(batches: list[pa.RecordBatch]) -> pa.RecordBatch:
if _PYARROW_HAS_CONCAT_BATCHES:
# pyarrow version 19+ supports the concat_batches API
return pa.concat_batches(batches)
else:
combined_table = pa.Table.from_batches(batches).combine_chunks()
combined_batches = combined_table.to_batches()
if len(combined_batches) > 1:
raise ValueError("Not able to combine multiple record batches into one record batch")
return combined_batches[0]


@csp.node
def write_record_batches(
where: str,
@@ -102,12 +117,12 @@ def write_record_batches(
with csp.stop():
if s_writer:
if s_prev_batch:
s_writer.write_batch(pa.concat_batches(s_prev_batch))
s_writer.write_batch(_concat_batches(s_prev_batch))
s_writer.close()

if csp.ticked(batches):
if s_merge_batches:
batches = [pa.concat_batches(batches)]
batches = [_concat_batches(batches)]

for batch in batches:
if len(batch) == 0:
@@ -118,7 +133,7 @@
s_prev_batch = [batch]
s_prev_batch_size = len(batch)
elif s_prev_batch_size + len(batch) > s_max_batch_size:
s_writer.write_batch(pa.concat_batches(s_prev_batch))
s_writer.write_batch(_concat_batches(s_prev_batch))
s_prev_batch = [batch]
s_prev_batch_size = len(batch)
else:
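For readers skimming the adapter change, here is a minimal standalone sketch of the version-gated concatenation pattern it introduces. The helper name `concat_record_batches` and the sample data are illustrative assumptions (nothing beyond pyarrow and packaging is required); the PR's actual helper is the `_concat_batches` function in the diff above.

```python
import pyarrow as pa
from packaging.version import parse

if parse(pa.__version__) >= parse("19.0.0"):
    # pyarrow 19+ ships a native concat_batches
    def concat_record_batches(batches):
        return pa.concat_batches(batches)
else:
    # Older pyarrow: round-trip through a Table and re-chunk into a single batch
    def concat_record_batches(batches):
        chunks = pa.Table.from_batches(batches).combine_chunks().to_batches()
        if len(chunks) > 1:
            raise ValueError("could not combine record batches into a single batch")
        return chunks[0]

a = pa.RecordBatch.from_pydict({"x": [1, 2]})
b = pa.RecordBatch.from_pydict({"x": [3, 4]})
assert concat_record_batches([a, b]).num_rows == 4
```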
21 changes: 14 additions & 7 deletions csp/tests/adapters/test_arrow.py
@@ -22,15 +22,22 @@ def WB(where: str, merge: bool, batch_size: int, batches: csp.ts[[pa.RecordBatch
data = write_record_batches(where, batches, {}, merge, batch_size)


def _concat_batches(batches: list[pa.RecordBatch]) -> pa.RecordBatch:
combined_table = pa.Table.from_batches(batches).combine_chunks()
combined_batches = combined_table.to_batches()
if len(combined_batches) > 1:
raise ValueError("Not able to combine multiple record batches into one record batch")
return combined_batches[0]


class TestArrow:
def make_record_batch(self, ts_col_name: str, row_size: int, ts: datetime) -> pa.RecordBatch:
data = {
ts_col_name: pa.array([ts] * row_size, type=pa.timestamp("ms")),
"name": pa.array([chr(ord("A") + idx % 26) for idx in range(row_size)]),
}
schema = pa.schema([(ts_col_name, pa.timestamp("ms")), ("name", pa.string())])
rb = pa.RecordBatch.from_pydict(data)
return rb.cast(schema)
return pa.RecordBatch.from_pydict(data, schema=schema)

def make_data(self, ts_col_name: str, row_sizes: [int], start: datetime = _STARTTIME, interval: int = 1):
res = [
@@ -100,7 +107,7 @@ def test_start_found(self, small_batches: bool, row_sizes: [int], row_sizes_prev
assert [len(r[1][0]) for r in results["data"]] == clean_row_sizes
assert [r[1][0] for r in results["data"]] == clean_rbs

results = csp.run(G, "TsCol", schema, [pa.concat_batches(full_rbs)], small_batches, starttime=dt_start - delta)
results = csp.run(G, "TsCol", schema, [_concat_batches(full_rbs)], small_batches, starttime=dt_start - delta)
assert len(results["data"]) == len(clean_row_sizes)
assert [len(r[1][0]) for r in results["data"]] == clean_row_sizes
assert [r[1][0] for r in results["data"]] == clean_rbs
@@ -126,7 +133,7 @@ def test_split(self, small_batches: bool, row_sizes: [int], repeat: int, dt_coun
for idx, tup in enumerate(results["data"]):
assert tup[1] == rbs_indivs[idx]

results = csp.run(G, "TsCol", schema, [pa.concat_batches(rbs_full)], small_batches, starttime=_STARTTIME)
results = csp.run(G, "TsCol", schema, [_concat_batches(rbs_full)], small_batches, starttime=_STARTTIME)
assert len(results["data"]) == len(rbs_indivs)
for idx, tup in enumerate(results["data"]):
assert pa.Table.from_batches(tup[1]) == pa.Table.from_batches(rbs_indivs[idx])
@@ -201,7 +208,7 @@ def test_write_record_batches_concat(self, row_sizes: [int], concat: bool):
if not concat:
rbs_ts_expected = [rb[0] for rb in rbs_ts]
else:
rbs_ts_expected = [pa.concat_batches(rbs_ts[0])]
rbs_ts_expected = [_concat_batches(rbs_ts[0])]
assert rbs_ts_expected == res.to_batches()

def test_write_record_batches_batch_sizes(self):
@@ -214,7 +221,7 @@ def test_write_record_batches_batch_sizes(self):
res = pq.read_table(temp_file.name)
orig = pa.Table.from_batches(rbs)
assert res.equals(orig)
rbs_ts_expected = [pa.concat_batches(rbs[2 * i : 2 * i + 2]) for i in range(5)]
rbs_ts_expected = [_concat_batches(rbs[2 * i : 2 * i + 2]) for i in range(5)]
assert rbs_ts_expected == res.to_batches()

row_sizes = [10] * 10
@@ -226,5 +233,5 @@
res = pq.read_table(temp_file.name)
orig = pa.Table.from_batches(rbs)
assert res.equals(orig)
rbs_ts_expected = [pa.concat_batches(rbs[3 * i : 3 * i + 3]) for i in range(4)]
rbs_ts_expected = [_concat_batches(rbs[3 * i : 3 * i + 3]) for i in range(4)]
assert rbs_ts_expected == res.to_batches()
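As a companion to the batch-size tests above, the following is a rough, non-csp sketch of the buffering behavior that `write_record_batches` is expected to exhibit: incoming batches accumulate until the next one would push the buffered row count past the configured maximum, at which point the buffer is flushed as one combined batch. The function name `group_batches` and the sample sizes are illustrative assumptions, not part of the PR.

```python
import pyarrow as pa

def group_batches(batches, max_rows):
    """Illustrative sketch of the buffering the tests exercise: accumulate
    batches and flush a combined batch once adding the next one would
    exceed max_rows rows."""
    buffered, buffered_rows, out = [], 0, []
    for batch in batches:
        if not buffered:
            buffered, buffered_rows = [batch], len(batch)
        elif buffered_rows + len(batch) > max_rows:
            out.append(pa.Table.from_batches(buffered).combine_chunks().to_batches()[0])
            buffered, buffered_rows = [batch], len(batch)
        else:
            buffered.append(batch)
            buffered_rows += len(batch)
    if buffered:
        out.append(pa.Table.from_batches(buffered).combine_chunks().to_batches()[0])
    return out

# Ten 10-row batches with max_rows=20 yield five 20-row batches,
# matching the pairwise grouping asserted in the first test case above.
rbs = [pa.RecordBatch.from_pydict({"x": list(range(10))}) for _ in range(10)]
grouped = group_batches(rbs, 20)
assert [len(b) for b in grouped] == [20] * 5
```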