From e81d2026f70151dd24eb7b897bbacbfc4ed41374 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 26 Jun 2026 11:13:32 +0800 Subject: [PATCH 1/4] support nest arg in array_aggregate Signed-off-by: Haoyang Li --- .../python/higher_order_functions_test.py | 91 +++++++++++++++++++ .../nvidia/spark/rapids/GpuOverrides.scala | 2 +- 2 files changed, 92 insertions(+), 1 deletion(-) diff --git a/integration_tests/src/main/python/higher_order_functions_test.py b/integration_tests/src/main/python/higher_order_functions_test.py index 55151a3a562..751900e309d 100644 --- a/integration_tests/src/main/python/higher_order_functions_test.py +++ b/integration_tests/src/main/python/higher_order_functions_test.py @@ -176,6 +176,97 @@ def do_it(spark): assert_gpu_and_cpu_are_equal_collect(do_it) +@disable_ansi_mode +def test_array_aggregate_filtered_struct_with_nested_array_children(): + def do_it(spark): + data = [ + ([(1, 10, [1, 2], ["a"]), (2, 20, [3], ["b"])],), + ([(1, 30, [], []), (1, 40, [4, 5], ["c", "d"])],), + ([],), + (None,) + ] + schema = """ + a array, + labels:array>> + """ + return spark.createDataFrame(data, schema).selectExpr(""" + aggregate( + filter(a, ad -> ad.product_id = 1), + 0L, + (acc, ad) -> acc + CAST(ad.score AS BIGINT), + id -> id + ) as total_score""") + assert_gpu_and_cpu_are_equal_collect(do_it) + + +@disable_ansi_mode +def test_array_aggregate_filtered_struct_with_nested_map_children(): + def do_it(spark): + data = [ + ([(1, 10, {"a": 1, "b": 2}), (2, 20, {"c": 3})],), + ([(1, 30, {}), (1, 40, {"d": 4, "e": 5})],), + ([],), + (None,) + ] + schema = """ + a array>> + """ + return spark.createDataFrame(data, schema).selectExpr(""" + aggregate( + filter(a, ad -> ad.product_id = 1), + 0L, + (acc, ad) -> acc + (CAST(ad.score AS BIGINT) + CAST(size(ad.attrs) AS BIGINT)), + id -> id + ) as score_and_attr_count""") + assert_gpu_and_cpu_are_equal_collect(do_it) + + +@disable_ansi_mode +def test_array_aggregate_nested_filter_and_aggregate_over_struct_array_field(): + def do_it(spark): + data = [ + ([ + (1, (["10\tfoo\tIGN_ZTC_CPA_CPC", "20\tbar\tMISS"],), [1, 2]), + (2, (["30\tbaz\tIGN_ZTC_CPA_CPC"],), [3]) + ],), + ([ + (1, (["5\tfoo\t-", "", "7\tbar\tIGN_ZTC_CPA_CPC"],), []), + (1, ([],), [4, 5]) + ],), + ([],), + (None,) + ] + schema = """ + a array>, + unused_ids:array>> + """ + return spark.createDataFrame(data, schema).selectExpr(""" + aggregate( + filter(a, ad -> ad.product_id = 1 AND ad.im_ad_res_field IS NOT NULL), + 0L, + (acc, ad) -> acc + coalesce( + aggregate( + filter(ad.im_ad_res_field.charge_info, z -> z != ''), + 0L, + (acc2, z) -> acc2 + CASE WHEN ( + size(split(z, '\t', -1)) > 2 + AND split(z, '\t', -1)[2] IN ('-', 'IGN_ZTC_CPA_CPC') + ) THEN CAST(split(z, '\t', -1)[0] AS BIGINT) ELSE 0L END, + id -> id), + 0L), + id -> id + ) as charge_sum""") + assert_gpu_and_cpu_are_equal_collect(do_it) + + @disable_ansi_mode def test_array_aggregate_non_zero_init(): assert_gpu_and_cpu_are_equal_collect( diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index f4c763b9308..774b6134cd4 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -3008,7 +3008,7 @@ object GpuOverrides extends Logging { Seq( ParamCheck("argument", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + - TypeSig.BINARY + TypeSig.STRUCT), + TypeSig.BINARY + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), TypeSig.ARRAY.nested(TypeSig.all)), ParamCheck("zero", TypeSig.commonCudfTypes + TypeSig.DECIMAL_128, From 61614bfd9f6a0ab9d2b355e2709c25fe5db4eeb7 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 26 Jun 2026 15:52:58 +0800 Subject: [PATCH 2/4] doc changes Signed-off-by: Haoyang Li --- docs/supported_ops.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/supported_ops.md b/docs/supported_ops.md index b97d230a6e8..de16a999f51 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -2469,7 +2469,7 @@ are limited. -PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types CALENDAR, ARRAY, MAP, UDT, DAYTIME, YEARMONTH
+PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types CALENDAR, UDT, DAYTIME, YEARMONTH
From b0646b3f312f700f8b14720e2ea1263966800b17 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 26 Jun 2026 15:59:54 +0800 Subject: [PATCH 3/4] integration tests clean up Signed-off-by: Haoyang Li --- .../python/higher_order_functions_test.py | 106 +++++++----------- 1 file changed, 41 insertions(+), 65 deletions(-) diff --git a/integration_tests/src/main/python/higher_order_functions_test.py b/integration_tests/src/main/python/higher_order_functions_test.py index 751900e309d..e5dc0fff90d 100644 --- a/integration_tests/src/main/python/higher_order_functions_test.py +++ b/integration_tests/src/main/python/higher_order_functions_test.py @@ -178,77 +178,54 @@ def do_it(spark): @disable_ansi_mode def test_array_aggregate_filtered_struct_with_nested_array_children(): - def do_it(spark): - data = [ - ([(1, 10, [1, 2], ["a"]), (2, 20, [3], ["b"])],), - ([(1, 30, [], []), (1, 40, [4, 5], ["c", "d"])],), - ([],), - (None,) - ] - schema = """ - a array, - labels:array>> - """ - return spark.createDataFrame(data, schema).selectExpr(""" - aggregate( - filter(a, ad -> ad.product_id = 1), - 0L, - (acc, ad) -> acc + CAST(ad.score AS BIGINT), - id -> id - ) as total_score""") - assert_gpu_and_cpu_are_equal_collect(do_it) + element_gen = StructGen([ + ('product_id', IntegerGen(nullable=False, min_val=0, max_val=2, special_cases=[0, 1, 2])), + ('score', IntegerGen(nullable=False, min_val=-100, max_val=100)), + ('nums', ArrayGen(IntegerGen(nullable=False), max_length=5, nullable=False)), + ('labels', ArrayGen(StringGen('[a-z]{1,3}', nullable=False), + max_length=5, nullable=False)) + ], nullable=False) + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, ArrayGen(element_gen, max_length=8)).selectExpr( + '''aggregate( + filter(a, ad -> ad.product_id = 1), + 0L, + (acc, ad) -> acc + CAST(ad.score AS BIGINT), + id -> id) as total_score''')) @disable_ansi_mode def test_array_aggregate_filtered_struct_with_nested_map_children(): - def do_it(spark): - data = [ - ([(1, 10, {"a": 1, "b": 2}), (2, 20, {"c": 3})],), - ([(1, 30, {}), (1, 40, {"d": 4, "e": 5})],), - ([],), - (None,) - ] - schema = """ - a array>> - """ - return spark.createDataFrame(data, schema).selectExpr(""" - aggregate( - filter(a, ad -> ad.product_id = 1), - 0L, - (acc, ad) -> acc + (CAST(ad.score AS BIGINT) + CAST(size(ad.attrs) AS BIGINT)), - id -> id - ) as score_and_attr_count""") - assert_gpu_and_cpu_are_equal_collect(do_it) + element_gen = StructGen([ + ('product_id', IntegerGen(nullable=False, min_val=0, max_val=2, special_cases=[0, 1, 2])), + ('score', IntegerGen(nullable=False, min_val=-100, max_val=100)), + ('attrs', MapGen(StringGen('key_[0-9]', nullable=False), + IntegerGen(nullable=False), max_length=5, nullable=False)) + ], nullable=False) + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, ArrayGen(element_gen, max_length=8)).selectExpr( + '''aggregate( + filter(a, ad -> ad.product_id = 1), + 0L, + (acc, ad) -> acc + + (CAST(ad.score AS BIGINT) + CAST(size(ad.attrs) AS BIGINT)), + id -> id) as score_and_attr_count''')) @disable_ansi_mode def test_array_aggregate_nested_filter_and_aggregate_over_struct_array_field(): - def do_it(spark): - data = [ - ([ - (1, (["10\tfoo\tIGN_ZTC_CPA_CPC", "20\tbar\tMISS"],), [1, 2]), - (2, (["30\tbaz\tIGN_ZTC_CPA_CPC"],), [3]) - ],), - ([ - (1, (["5\tfoo\t-", "", "7\tbar\tIGN_ZTC_CPA_CPC"],), []), - (1, ([],), [4, 5]) - ],), - ([],), - (None,) - ] - schema = """ - a array>, - unused_ids:array>> - """ - return spark.createDataFrame(data, schema).selectExpr(""" + charge_info_gen = StringGen( + '[0-9]{1,2}\t[a-z]{2}\t(IGN_ZTC_CPA_CPC|MISS|-)', nullable=False + ).with_special_case('', weight=5.0) + element_gen = StructGen([ + ('product_id', IntegerGen(nullable=False, min_val=0, max_val=2, special_cases=[0, 1, 2])), + ('im_ad_res_field', StructGen([ + ('charge_info', ArrayGen(charge_info_gen, max_length=5, nullable=False)) + ])), + ('unused_ids', ArrayGen(IntegerGen(nullable=False), max_length=5, nullable=False)) + ], nullable=False) + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, ArrayGen(element_gen, max_length=8)).selectExpr(""" aggregate( filter(a, ad -> ad.product_id = 1 AND ad.im_ad_res_field IS NOT NULL), 0L, @@ -263,8 +240,7 @@ def do_it(spark): id -> id), 0L), id -> id - ) as charge_sum""") - assert_gpu_and_cpu_are_equal_collect(do_it) + ) as charge_sum""")) @disable_ansi_mode From 55c3285822ba865bc762fc2e1e0f4211ef32ca0c Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 26 Jun 2026 17:32:09 +0800 Subject: [PATCH 4/4] add test coverage Signed-off-by: Haoyang Li --- .../src/main/python/higher_order_functions_test.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/integration_tests/src/main/python/higher_order_functions_test.py b/integration_tests/src/main/python/higher_order_functions_test.py index e5dc0fff90d..73e95359e32 100644 --- a/integration_tests/src/main/python/higher_order_functions_test.py +++ b/integration_tests/src/main/python/higher_order_functions_test.py @@ -212,6 +212,19 @@ def test_array_aggregate_filtered_struct_with_nested_map_children(): id -> id) as score_and_attr_count''')) +@pytest.mark.parametrize('element_gen', [ + ArrayGen(IntegerGen(nullable=False), max_length=5, nullable=False), + MapGen(StringGen('key_[0-9]', nullable=False), + IntegerGen(nullable=False), max_length=5, nullable=False) +], ids=['array-element', 'map-element']) +@disable_ansi_mode +def test_array_aggregate_direct_nested_collection_elements(element_gen): + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, ArrayGen(element_gen, max_length=8)).selectExpr( + 'aggregate(a, 0L, (acc, x) -> acc + CAST(size(x) AS BIGINT), id -> id) ' + 'as total_size')) + + @disable_ansi_mode def test_array_aggregate_nested_filter_and_aggregate_over_struct_array_field(): charge_info_gen = StringGen(