@@ -22,15 +22,22 @@ def WB(where: str, merge: bool, batch_size: int, batches: csp.ts[[pa.RecordBatch
    data = write_record_batches(where, batches, {}, merge, batch_size)


+def _concat_batches(batches: list[pa.RecordBatch]) -> pa.RecordBatch:
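+    # Stand-in for pa.concat_batches (only available in recent pyarrow
+    # releases): round-trip the batches through a Table, merge its chunks,
+    # and return the single resulting RecordBatch.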
+    combined_table = pa.Table.from_batches(batches).combine_chunks()
+    combined_batches = combined_table.to_batches()
+    if len(combined_batches) > 1:
+        raise ValueError("Not able to combine multiple record batches into one record batch")
+    return combined_batches[0]
+
+
class TestArrow:
    def make_record_batch(self, ts_col_name: str, row_size: int, ts: datetime) -> pa.RecordBatch:
        data = {
            ts_col_name: pa.array([ts] * row_size, type=pa.timestamp("ms")),
            "name": pa.array([chr(ord("A") + idx % 26) for idx in range(row_size)]),
        }
        schema = pa.schema([(ts_col_name, pa.timestamp("ms")), ("name", pa.string())])
-        rb = pa.RecordBatch.from_pydict(data)
-        return rb.cast(schema)
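+        # Build the batch against the target schema directly instead of
+        # constructing it untyped and casting afterwards.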
+        return pa.RecordBatch.from_pydict(data, schema=schema)

    def make_data(self, ts_col_name: str, row_sizes: [int], start: datetime = _STARTTIME, interval: int = 1):
        res = [
@@ -100,7 +107,7 @@ def test_start_found(self, small_batches: bool, row_sizes: [int], row_sizes_prev
        assert [len(r[1][0]) for r in results["data"]] == clean_row_sizes
        assert [r[1][0] for r in results["data"]] == clean_rbs

-        results = csp.run(G, "TsCol", schema, [pa.concat_batches(full_rbs)], small_batches, starttime=dt_start - delta)
+        results = csp.run(G, "TsCol", schema, [_concat_batches(full_rbs)], small_batches, starttime=dt_start - delta)
        assert len(results["data"]) == len(clean_row_sizes)
        assert [len(r[1][0]) for r in results["data"]] == clean_row_sizes
        assert [r[1][0] for r in results["data"]] == clean_rbs
@@ -126,7 +133,7 @@ def test_split(self, small_batches: bool, row_sizes: [int], repeat: int, dt_coun
        for idx, tup in enumerate(results["data"]):
            assert tup[1] == rbs_indivs[idx]

-        results = csp.run(G, "TsCol", schema, [pa.concat_batches(rbs_full)], small_batches, starttime=_STARTTIME)
+        results = csp.run(G, "TsCol", schema, [_concat_batches(rbs_full)], small_batches, starttime=_STARTTIME)
        assert len(results["data"]) == len(rbs_indivs)
        for idx, tup in enumerate(results["data"]):
            assert pa.Table.from_batches(tup[1]) == pa.Table.from_batches(rbs_indivs[idx])
@@ -201,7 +208,7 @@ def test_write_record_batches_concat(self, row_sizes: [int], concat: bool):
        if not concat:
            rbs_ts_expected = [rb[0] for rb in rbs_ts]
        else:
-            rbs_ts_expected = [pa.concat_batches(rbs_ts[0])]
+            rbs_ts_expected = [_concat_batches(rbs_ts[0])]
        assert rbs_ts_expected == res.to_batches()

    def test_write_record_batches_batch_sizes(self):
@@ -214,7 +221,7 @@ def test_write_record_batches_batch_sizes(self):
        res = pq.read_table(temp_file.name)
        orig = pa.Table.from_batches(rbs)
        assert res.equals(orig)
-        rbs_ts_expected = [pa.concat_batches(rbs[2 * i : 2 * i + 2]) for i in range(5)]
+        rbs_ts_expected = [_concat_batches(rbs[2 * i : 2 * i + 2]) for i in range(5)]
        assert rbs_ts_expected == res.to_batches()

        row_sizes = [10] * 10
@@ -226,5 +233,5 @@ def test_write_record_batches_batch_sizes(self):
        res = pq.read_table(temp_file.name)
        orig = pa.Table.from_batches(rbs)
        assert res.equals(orig)
-        rbs_ts_expected = [pa.concat_batches(rbs[3 * i : 3 * i + 3]) for i in range(4)]
+        rbs_ts_expected = [_concat_batches(rbs[3 * i : 3 * i + 3]) for i in range(4)]
        assert rbs_ts_expected == res.to_batches()