Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions integration_tests/src/main/python/array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,21 @@ def do_it(spark):
assert_gpu_and_cpu_are_equal_collect(do_it)


@disable_ansi_mode
def test_array_heterogeneous_elementwise_hof_mixed_project():
    # A single projection that passes `a` and `b` through unchanged and also applies
    # transform / filter / exists lambdas whose bodies reference both `a` and `b`.
    array_gen = ArrayGen(IntegerGen(min_val=-10, max_val=10), max_length=8)
    projections = [
        'a',
        'b',
        'transform(a, item -> item + b) as plus_b',
        'transform(a, item -> item is not null and item >= b) as at_least_b',
        'filter(a, item -> item is not null and item >= b) as filtered_at_least_b',
        'exists(a, item -> item is not null and item < b) as has_less_than_b',
    ]

    def build_df(spark):
        # The second generator is created on every call, as in the inline form.
        source_df = two_col_df(spark, array_gen, IntegerGen(min_val=-5, max_val=5))
        return source_df.selectExpr(*projections)

    assert_gpu_and_cpu_are_equal_collect(build_df)


# The sample array generators extended with two more ArrayGen instances: one whose
# elements come from map_string_string_gen[0] and one whose elements come from
# BinaryGen, each capped at max_length=5.
array_zips_gen = array_gens_sample + [
    ArrayGen(map_string_string_gen[0], max_length=5),
    ArrayGen(BinaryGen(max_length=5), max_length=5),
]

Expand Down
15 changes: 15 additions & 0 deletions integration_tests/src/main/python/higher_order_functions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,21 @@ def test_array_aggregate_count_if_int():
'aggregate(a, 0L, (acc, x) -> acc + CAST(CASE WHEN x IS NULL THEN 1 ELSE 0 END as BIGINT)) as null_cnt'))


@disable_ansi_mode
def test_array_hof_mixed_project_with_aggregate():
    # One projection combining transform / filter / exists with an aggregate over the
    # same array column `a`; the aggregate adds 0 for null elements.
    array_gen = ArrayGen(IntegerGen(min_val=-10, max_val=10), max_length=8)
    # The last entry is a multi-line SQL literal; its internal whitespace is kept as-is.
    projections = [
        'transform(a, x -> x + 1) as plus_one',
        'filter(a, x -> x is not null and x >= 0) as non_negative',
        'exists(a, x -> x is not null and x < 0) as has_negative',
        '''aggregate(a, 0L,
(acc, x) -> acc + CAST(CASE WHEN x IS NULL THEN 0 ELSE x END AS BIGINT))
as sum_or_zero''',
    ]

    def build_df(spark):
        source_df = unary_op_df(spark, array_gen)
        return source_df.selectExpr(*projections)

    assert_gpu_and_cpu_are_equal_collect(build_df)


Comment thread
thirtiseven marked this conversation as resolved.
# `if(cond, acc + t, acc)` shape — branches lifted via op identity. Same count-if
# pattern as above but written naturally instead of using `CASE WHEN ... THEN 1 ELSE 0`.
@disable_ansi_mode
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,10 @@ object GpuProjectExec {
// different vector length, thus not able to reuse cached vectors.
GpuExpressionsUtils.cachedNullVectors.get.clear()

val newColumns = boundExprs.safeMap(_.columnarEval(cb)).toArray[ColumnVector]
new ColumnarBatch(newColumns, cb.numRows())
GpuArrayTransformFusion.project(cb, boundExprs).getOrElse {
val newColumns = boundExprs.safeMap(_.columnarEval(cb)).toArray[ColumnVector]
new ColumnarBatch(newColumns, cb.numRows())
}
} finally {
GpuExpressionsUtils.cachedNullVectors.get.clear()
}
Expand Down
Loading