Skip to content

Commit c609dfa

Browse files
authored
feat: allow passing a slice to an expression with the [] indexing (#1215)
* Allow passing a slice to an expression with the [] indexing * Update documentation * Add unit test covering expressions in slice
1 parent b325a38 commit c609dfa

File tree

3 files changed

+64
-3
lines changed

3 files changed

+64
-3
lines changed

docs/source/user-guide/common-operations/expressions.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,13 @@ approaches.
8282
Indexing an element of an array via ``[]`` starts at index 0 whereas
8383
:py:func:`~datafusion.functions.array_element` starts at index 1.
8484

85+
Starting in DataFusion 49.0.0 you can also create slices of array elements using
86+
slice syntax from Python.
87+
88+
.. ipython:: python
89+
90+
df.select(col("a")[1:3].alias("second_two_elements"))
91+
8592
To check if an array is empty, you can use the function :py:func:`datafusion.functions.array_empty` or :py:func:`datafusion.functions.empty`.
8693
This function returns a boolean indicating whether the array is empty.
8794

python/datafusion/expr.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,17 +352,44 @@ def __invert__(self) -> Expr:
352352
"""Binary not (~)."""
353353
return Expr(self.expr.__invert__())
354354

355-
def __getitem__(self, key: str | int) -> Expr:
355+
def __getitem__(self, key: str | int | slice) -> Expr:
    """Retrieve sub-object.

    If ``key`` is a string, returns the subfield of the struct.
    If ``key`` is an integer, retrieves the element in the array. Note that the
    element index begins at ``0``, unlike
    :py:func:`~datafusion.functions.array_element` which begins at ``1``.
    If ``key`` is a slice, returns an array that contains a slice of the
    original array. Similar to integer indexing, this follows Python convention
    where the index begins at ``0`` unlike
    :py:func:`~datafusion.functions.array_slice` which begins at ``1``.

    Negative integer positions (``expr[-1]``, ``expr[-2:n]``) are forwarded
    unchanged so that they keep counting from the end of the array. An omitted
    slice start begins at the first element and an omitted slice stop
    (``expr[1:]``) runs through the last element.

    Args:
        key: Struct field name, array position, or slice whose ``start``,
            ``stop`` and ``step`` may each be an ``int``, an ``Expr`` or
            omitted.

    Returns:
        A new expression selecting the requested field, element, or sub-array.
    """
    if isinstance(key, int):
        # Only non-negative positions need the 0-based -> 1-based shift. Adding
        # one to a negative position would turn ``-1`` into ``0`` instead of
        # addressing the last element.
        index = key + 1 if key >= 0 else key
        return Expr(
            functions_internal.array_element(self.expr, Expr.literal(index).expr)
        )

    if isinstance(key, slice):
        if isinstance(key.start, int):
            # Same rule as integer indexing: shift non-negative starts only.
            first = key.start + 1 if key.start >= 0 else key.start
            start = Expr.literal(first).expr
        elif isinstance(key.start, Expr):
            # The sign of an expression is unknown until execution, so the
            # 0-based -> 1-based shift is always applied here.
            start = (key.start + Expr.literal(1)).expr
        else:
            # Default start at the first element, index 1
            start = Expr.literal(1).expr

        # A 0-based exclusive stop addresses the same element as a 1-based
        # inclusive end, so the stop value is forwarded without a shift.
        # NOTE(review): negative stop values are forwarded unchanged as well;
        # confirm they match Python's exclusive-stop behaviour.
        if isinstance(key.stop, int):
            stop = Expr.literal(key.stop).expr
        elif key.stop is None:
            # Open-ended slice: run through the final element rather than
            # failing on ``None.expr``.
            stop = functions_internal.array_length(self.expr)
        else:
            stop = key.stop.expr

        if isinstance(key.step, int):
            step = Expr.literal(key.step).expr
        elif isinstance(key.step, Expr):
            step = key.step.expr
        else:
            # Typically ``None`` when no step was written; passed through as-is.
            step = key.step

        return Expr(functions_internal.array_slice(self.expr, start, stop, step))

    return Expr(self.expr.__getitem__(key))
367394

368395
def __eq__(self, rhs: object) -> Expr:

python/tests/test_functions.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,30 @@ def py_flatten(arr):
494494
lambda col: f.list_slice(col, literal(-1), literal(2)),
495495
lambda data: [arr[-1:2] for arr in data],
496496
),
497+
(
498+
lambda col: col[:3],
499+
lambda data: [arr[:3] for arr in data],
500+
),
501+
(
502+
lambda col: col[1:3],
503+
lambda data: [arr[1:3] for arr in data],
504+
),
505+
(
506+
lambda col: col[1:4:2],
507+
lambda data: [arr[1:4:2] for arr in data],
508+
),
509+
(
510+
lambda col: col[literal(1) : literal(4)],
511+
lambda data: [arr[1:4] for arr in data],
512+
),
513+
(
514+
lambda col: col[column("indices") : column("indices") + literal(2)],
515+
lambda data: [[2.0, 3.0], [], [6.0]],
516+
),
517+
(
518+
lambda col: col[literal(1) : literal(4) : literal(2)],
519+
lambda data: [arr[1:4:2] for arr in data],
520+
),
497521
(
498522
lambda col: f.array_intersect(col, literal([3.0, 4.0])),
499523
lambda data: [np.intersect1d(arr, [3.0, 4.0]) for arr in data],
@@ -534,8 +558,11 @@ def py_flatten(arr):
534558
)
535559
def test_array_functions(stmt, py_expr):
536560
data = [[1.0, 2.0, 3.0, 3.0], [4.0, 5.0, 3.0], [6.0]]
561+
indices = [1, 3, 0]
537562
ctx = SessionContext()
538-
batch = pa.RecordBatch.from_arrays([np.array(data, dtype=object)], names=["arr"])
563+
batch = pa.RecordBatch.from_arrays(
564+
[np.array(data, dtype=object), indices], names=["arr", "indices"]
565+
)
539566
df = ctx.create_dataframe([[batch]])
540567

541568
col = column("arr")

0 commit comments

Comments
 (0)