Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -2469,7 +2469,7 @@ are limited.
<td> </td>
<td> </td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types CALENDAR, ARRAY, MAP, UDT, DAYTIME, YEARMONTH</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types CALENDAR, UDT, DAYTIME, YEARMONTH</em></td>
<td> </td>
<td> </td>
<td> </td>
Expand Down
67 changes: 67 additions & 0 deletions integration_tests/src/main/python/higher_order_functions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,73 @@ def do_it(spark):
assert_gpu_and_cpu_are_equal_collect(do_it)


@disable_ansi_mode
def test_array_aggregate_filtered_struct_with_nested_array_children():
element_gen = StructGen([
('product_id', IntegerGen(nullable=False, min_val=0, max_val=2, special_cases=[0, 1, 2])),
('score', IntegerGen(nullable=False, min_val=-100, max_val=100)),
('nums', ArrayGen(IntegerGen(nullable=False), max_length=5, nullable=False)),
('labels', ArrayGen(StringGen('[a-z]{1,3}', nullable=False),
max_length=5, nullable=False))
], nullable=False)
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, ArrayGen(element_gen, max_length=8)).selectExpr(
'''aggregate(
filter(a, ad -> ad.product_id = 1),
0L,
(acc, ad) -> acc + CAST(ad.score AS BIGINT),
id -> id) as total_score'''))


@disable_ansi_mode
def test_array_aggregate_filtered_struct_with_nested_map_children():
element_gen = StructGen([
('product_id', IntegerGen(nullable=False, min_val=0, max_val=2, special_cases=[0, 1, 2])),
('score', IntegerGen(nullable=False, min_val=-100, max_val=100)),
('attrs', MapGen(StringGen('key_[0-9]', nullable=False),
IntegerGen(nullable=False), max_length=5, nullable=False))
], nullable=False)
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, ArrayGen(element_gen, max_length=8)).selectExpr(
'''aggregate(
filter(a, ad -> ad.product_id = 1),
0L,
(acc, ad) -> acc +
(CAST(ad.score AS BIGINT) + CAST(size(ad.attrs) AS BIGINT)),
id -> id) as score_and_attr_count'''))


@disable_ansi_mode
def test_array_aggregate_nested_filter_and_aggregate_over_struct_array_field():
charge_info_gen = StringGen(
'[0-9]{1,2}\t[a-z]{2}\t(IGN_ZTC_CPA_CPC|MISS|-)', nullable=False
).with_special_case('', weight=5.0)
element_gen = StructGen([
('product_id', IntegerGen(nullable=False, min_val=0, max_val=2, special_cases=[0, 1, 2])),
('im_ad_res_field', StructGen([
('charge_info', ArrayGen(charge_info_gen, max_length=5, nullable=False))
])),
('unused_ids', ArrayGen(IntegerGen(nullable=False), max_length=5, nullable=False))
], nullable=False)
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, ArrayGen(element_gen, max_length=8)).selectExpr("""
aggregate(
filter(a, ad -> ad.product_id = 1 AND ad.im_ad_res_field IS NOT NULL),
0L,
(acc, ad) -> acc + coalesce(
aggregate(
filter(ad.im_ad_res_field.charge_info, z -> z != ''),
0L,
(acc2, z) -> acc2 + CASE WHEN (
size(split(z, '\t', -1)) > 2
AND split(z, '\t', -1)[2] IN ('-', 'IGN_ZTC_CPA_CPC')
) THEN CAST(split(z, '\t', -1)[0] AS BIGINT) ELSE 0L END,
id -> id),
0L),
id -> id
) as charge_sum"""))


@disable_ansi_mode
def test_array_aggregate_non_zero_init():
assert_gpu_and_cpu_are_equal_collect(
Comment thread
greptile-apps[bot] marked this conversation as resolved.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3008,7 +3008,7 @@ object GpuOverrides extends Logging {
Seq(
ParamCheck("argument",
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL +
TypeSig.BINARY + TypeSig.STRUCT),
TypeSig.BINARY + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP),
TypeSig.ARRAY.nested(TypeSig.all)),
ParamCheck("zero",
TypeSig.commonCudfTypes + TypeSig.DECIMAL_128,
Expand Down
Loading