diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index b97d230a6e8..de16a999f51 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -2469,7 +2469,7 @@ are limited.
|
|
|
-PS UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, ARRAY, MAP, UDT, DAYTIME, YEARMONTH |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, UDT, DAYTIME, YEARMONTH |
|
|
|
diff --git a/integration_tests/src/main/python/higher_order_functions_test.py b/integration_tests/src/main/python/higher_order_functions_test.py
index 55151a3a562..73e95359e32 100644
--- a/integration_tests/src/main/python/higher_order_functions_test.py
+++ b/integration_tests/src/main/python/higher_order_functions_test.py
@@ -176,6 +176,86 @@ def do_it(spark):
assert_gpu_and_cpu_are_equal_collect(do_it)
+@disable_ansi_mode
+def test_array_aggregate_filtered_struct_with_nested_array_children():
+ element_gen = StructGen([
+ ('product_id', IntegerGen(nullable=False, min_val=0, max_val=2, special_cases=[0, 1, 2])),
+ ('score', IntegerGen(nullable=False, min_val=-100, max_val=100)),
+ ('nums', ArrayGen(IntegerGen(nullable=False), max_length=5, nullable=False)),
+ ('labels', ArrayGen(StringGen('[a-z]{1,3}', nullable=False),
+ max_length=5, nullable=False))
+ ], nullable=False)
+ assert_gpu_and_cpu_are_equal_collect(
+ lambda spark: unary_op_df(spark, ArrayGen(element_gen, max_length=8)).selectExpr(
+ '''aggregate(
+ filter(a, ad -> ad.product_id = 1),
+ 0L,
+ (acc, ad) -> acc + CAST(ad.score AS BIGINT),
+ id -> id) as total_score'''))
+
+
+@disable_ansi_mode
+def test_array_aggregate_filtered_struct_with_nested_map_children():
+ element_gen = StructGen([
+ ('product_id', IntegerGen(nullable=False, min_val=0, max_val=2, special_cases=[0, 1, 2])),
+ ('score', IntegerGen(nullable=False, min_val=-100, max_val=100)),
+ ('attrs', MapGen(StringGen('key_[0-9]', nullable=False),
+ IntegerGen(nullable=False), max_length=5, nullable=False))
+ ], nullable=False)
+ assert_gpu_and_cpu_are_equal_collect(
+ lambda spark: unary_op_df(spark, ArrayGen(element_gen, max_length=8)).selectExpr(
+ '''aggregate(
+ filter(a, ad -> ad.product_id = 1),
+ 0L,
+ (acc, ad) -> acc +
+ (CAST(ad.score AS BIGINT) + CAST(size(ad.attrs) AS BIGINT)),
+ id -> id) as score_and_attr_count'''))
+
+
+@pytest.mark.parametrize('element_gen', [
+ ArrayGen(IntegerGen(nullable=False), max_length=5, nullable=False),
+ MapGen(StringGen('key_[0-9]', nullable=False),
+ IntegerGen(nullable=False), max_length=5, nullable=False)
+], ids=['array-element', 'map-element'])
+@disable_ansi_mode
+def test_array_aggregate_direct_nested_collection_elements(element_gen):
+ assert_gpu_and_cpu_are_equal_collect(
+ lambda spark: unary_op_df(spark, ArrayGen(element_gen, max_length=8)).selectExpr(
+ 'aggregate(a, 0L, (acc, x) -> acc + CAST(size(x) AS BIGINT), id -> id) '
+ 'as total_size'))
+
+
+@disable_ansi_mode
+def test_array_aggregate_nested_filter_and_aggregate_over_struct_array_field():
+ charge_info_gen = StringGen(
+ '[0-9]{1,2}\t[a-z]{2}\t(IGN_ZTC_CPA_CPC|MISS|-)', nullable=False
+ ).with_special_case('', weight=5.0)
+ element_gen = StructGen([
+ ('product_id', IntegerGen(nullable=False, min_val=0, max_val=2, special_cases=[0, 1, 2])),
+ ('im_ad_res_field', StructGen([
+ ('charge_info', ArrayGen(charge_info_gen, max_length=5, nullable=False))
+ ])),
+ ('unused_ids', ArrayGen(IntegerGen(nullable=False), max_length=5, nullable=False))
+ ], nullable=False)
+ assert_gpu_and_cpu_are_equal_collect(
+ lambda spark: unary_op_df(spark, ArrayGen(element_gen, max_length=8)).selectExpr("""
+ aggregate(
+ filter(a, ad -> ad.product_id = 1 AND ad.im_ad_res_field IS NOT NULL),
+ 0L,
+ (acc, ad) -> acc + coalesce(
+ aggregate(
+ filter(ad.im_ad_res_field.charge_info, z -> z != ''),
+ 0L,
+ (acc2, z) -> acc2 + CASE WHEN (
+ size(split(z, '\t', -1)) > 2
+ AND split(z, '\t', -1)[2] IN ('-', 'IGN_ZTC_CPA_CPC')
+ ) THEN CAST(split(z, '\t', -1)[0] AS BIGINT) ELSE 0L END,
+ id -> id),
+ 0L),
+ id -> id
+ ) as charge_sum"""))
+
+
@disable_ansi_mode
def test_array_aggregate_non_zero_init():
assert_gpu_and_cpu_are_equal_collect(
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index f4c763b9308..774b6134cd4 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -3008,7 +3008,7 @@ object GpuOverrides extends Logging {
Seq(
ParamCheck("argument",
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL +
- TypeSig.BINARY + TypeSig.STRUCT),
+ TypeSig.BINARY + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP),
TypeSig.ARRAY.nested(TypeSig.all)),
ParamCheck("zero",
TypeSig.commonCudfTypes + TypeSig.DECIMAL_128,