
Commit f73da80

kevinjqliu authored and sungwy committed
[bug] fix reading with to_arrow_batch_reader and limit (#1042)
* fix project_batches with limit
* add test
* lint + readability
1 parent 5e89fc5 commit f73da80
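
For context, the call path this commit fixes looks like the sketch below; the catalog name, table identifier, and column are placeholders for illustration, not part of the commit:

from pyiceberg.catalog import load_catalog

# Placeholder setup: the "default" catalog and table are assumed
# for illustration only.
catalog = load_catalog("default")
table = catalog.load_table("default.test_limit")

# Before this fix, the batch-reader path could yield more rows than
# `limit` and kept reading remaining files after the limit was met.
reader = table.scan(selected_fields=("idx",), limit=1).to_arrow_batch_reader()
assert len(reader.read_all()) == 1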

File tree

2 files changed: +54 −2


Diff for: pyiceberg/io/pyarrow.py

+6 −2
@@ -1409,6 +1409,9 @@ def project_batches(
     total_row_count = 0
 
     for task in tasks:
+        # stop early if limit is satisfied
+        if limit is not None and total_row_count >= limit:
+            break
         batches = _task_to_record_batches(
             fs,
             task,
@@ -1421,9 +1424,10 @@ def project_batches(
         )
         for batch in batches:
             if limit is not None:
-                if total_row_count + len(batch) >= limit:
-                    yield batch.slice(0, limit - total_row_count)
+                if total_row_count >= limit:
                     break
+                elif total_row_count + len(batch) >= limit:
+                    batch = batch.slice(0, limit - total_row_count)
             yield batch
             total_row_count += len(batch)
 
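To see what both hunks do together, here is a minimal standalone sketch of the corrected limit logic, using plain pyarrow RecordBatches in place of PyIceberg scan tasks; batches_with_limit and the nested tasks list are illustrative stand-ins, not PyIceberg's API:

import pyarrow as pa

def batches_with_limit(tasks, limit=None):
    total_row_count = 0
    for task in tasks:
        # first hunk: without this early break, every remaining file
        # would still be opened after the limit is already reached
        if limit is not None and total_row_count >= limit:
            break
        for batch in task:
            if limit is not None:
                if total_row_count >= limit:
                    break
                # second hunk: trim the batch that crosses the limit,
                # then fall through so it is yielded exactly once
                elif total_row_count + len(batch) >= limit:
                    batch = batch.slice(0, limit - total_row_count)
            yield batch
            total_row_count += len(batch)

def make_batch(lo):
    # one 10-row "file" starting at value `lo`
    return pa.record_batch([pa.array(range(lo, lo + 10))], names=["idx"])

# two files of 10 rows each; a limit of 12 crosses the file boundary
tasks = [[make_batch(0)], [make_batch(10)]]
assert sum(len(b) for b in batches_with_limit(tasks, limit=12)) == 12
assert sum(len(b) for b in batches_with_limit(tasks, limit=0)) == 0

The old code sliced and yielded inside the limit check but only broke out of the inner loop, and total_row_count was never advanced for the sliced batch, so each subsequent file yielded another slice; with two files and limit=1 the reader returned 2 rows.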

Diff for: tests/integration/test_reads.py

+48
@@ -240,6 +240,54 @@ def test_pyarrow_limit(catalog: Catalog) -> None:
     full_result = table_test_limit.scan(selected_fields=("idx",), limit=999).to_arrow()
     assert len(full_result) == 10
 
+    # test `to_arrow_batch_reader`
+    limited_result = table_test_limit.scan(selected_fields=("idx",), limit=1).to_arrow_batch_reader().read_all()
+    assert len(limited_result) == 1
+
+    empty_result = table_test_limit.scan(selected_fields=("idx",), limit=0).to_arrow_batch_reader().read_all()
+    assert len(empty_result) == 0
+
+    full_result = table_test_limit.scan(selected_fields=("idx",), limit=999).to_arrow_batch_reader().read_all()
+    assert len(full_result) == 10
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
+def test_pyarrow_limit_with_multiple_files(catalog: Catalog) -> None:
+    table_name = "default.test_pyarrow_limit_with_multiple_files"
+    try:
+        catalog.drop_table(table_name)
+    except NoSuchTableError:
+        pass
+    reference_table = catalog.load_table("default.test_limit")
+    data = reference_table.scan().to_arrow()
+    table_test_limit = catalog.create_table(table_name, schema=reference_table.schema())
+
+    n_files = 2
+    for _ in range(n_files):
+        table_test_limit.append(data)
+    assert len(table_test_limit.inspect.files()) == n_files
+
+    # test with multiple files
+    limited_result = table_test_limit.scan(selected_fields=("idx",), limit=1).to_arrow()
+    assert len(limited_result) == 1
+
+    empty_result = table_test_limit.scan(selected_fields=("idx",), limit=0).to_arrow()
+    assert len(empty_result) == 0
+
+    full_result = table_test_limit.scan(selected_fields=("idx",), limit=999).to_arrow()
+    assert len(full_result) == 10 * n_files
+
+    # test `to_arrow_batch_reader`
+    limited_result = table_test_limit.scan(selected_fields=("idx",), limit=1).to_arrow_batch_reader().read_all()
+    assert len(limited_result) == 1
+
+    empty_result = table_test_limit.scan(selected_fields=("idx",), limit=0).to_arrow_batch_reader().read_all()
+    assert len(empty_result) == 0
+
+    full_result = table_test_limit.scan(selected_fields=("idx",), limit=999).to_arrow_batch_reader().read_all()
+    assert len(full_result) == 10 * n_files
+
 
 @pytest.mark.integration
 @pytest.mark.filterwarnings("ignore")
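
The new assertions drain the reader with read_all(), which concatenates every batch the scan yields into a single pyarrow Table, so the table's length is the total row count produced under the limit. A quick standalone illustration of that behavior (the schema and data here are made up for the example):

import pyarrow as pa

schema = pa.schema([("idx", pa.int64())])
batches = [pa.record_batch([pa.array([0, 1, 2], type=pa.int64())], schema=schema)]
reader = pa.RecordBatchReader.from_batches(schema, iter(batches))
assert len(reader.read_all()) == 3  # one 3-row batch -> a 3-row Table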
