diff --git a/docs/supported_ops.md b/docs/supported_ops.md index b97d230a6e8..de16a999f51 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -2469,7 +2469,7 @@ are limited. -PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types CALENDAR, ARRAY, MAP, UDT, DAYTIME, YEARMONTH
+PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types CALENDAR, UDT, DAYTIME, YEARMONTH
diff --git a/integration_tests/src/main/python/higher_order_functions_test.py b/integration_tests/src/main/python/higher_order_functions_test.py
index 55151a3a562..73e95359e32 100644
--- a/integration_tests/src/main/python/higher_order_functions_test.py
+++ b/integration_tests/src/main/python/higher_order_functions_test.py
@@ -176,6 +176,86 @@ def do_it(spark):
     assert_gpu_and_cpu_are_equal_collect(do_it)
 
 
+@disable_ansi_mode
+def test_array_aggregate_filtered_struct_with_nested_array_children():  # aggregate() over filter() of structs
+    element_gen = StructGen([  # nums/labels: ARRAY children the SQL below never reads
+        ('product_id', IntegerGen(nullable=False, min_val=0, max_val=2, special_cases=[0, 1, 2])),
+        ('score', IntegerGen(nullable=False, min_val=-100, max_val=100)),
+        ('nums', ArrayGen(IntegerGen(nullable=False), max_length=5, nullable=False)),
+        ('labels', ArrayGen(StringGen('[a-z]{1,3}', nullable=False),
+                            max_length=5, nullable=False))
+    ], nullable=False)
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark: unary_op_df(spark, ArrayGen(element_gen, max_length=8)).selectExpr(
+            '''aggregate(
+                filter(a, ad -> ad.product_id = 1),
+                0L,
+                (acc, ad) -> acc + CAST(ad.score AS BIGINT),
+                id -> id) as total_score'''))
+
+
+@disable_ansi_mode
+def test_array_aggregate_filtered_struct_with_nested_map_children():  # same query shape, MAP child instead
+    element_gen = StructGen([  # attrs: MAP child, read by the SQL below only through size()
+        ('product_id', IntegerGen(nullable=False, min_val=0, max_val=2, special_cases=[0, 1, 2])),
+        ('score', IntegerGen(nullable=False, min_val=-100, max_val=100)),
+        ('attrs', MapGen(StringGen('key_[0-9]', nullable=False),
+                         IntegerGen(nullable=False), max_length=5, nullable=False))
+    ], nullable=False)
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark: unary_op_df(spark, ArrayGen(element_gen, max_length=8)).selectExpr(
+            '''aggregate(
+                filter(a, ad -> ad.product_id = 1),
+                0L,
+                (acc, ad) -> acc
+                    + (CAST(ad.score AS BIGINT) + CAST(size(ad.attrs) AS BIGINT)),
+                id -> id) as score_and_attr_count'''))
+
+
+@pytest.mark.parametrize('element_gen', [
+    ArrayGen(IntegerGen(nullable=False), max_length=5, nullable=False),
+    MapGen(StringGen('key_[0-9]', nullable=False),
+           IntegerGen(nullable=False), max_length=5, nullable=False)
+], ids=['array-element', 'map-element'])
+@disable_ansi_mode
+def test_array_aggregate_direct_nested_collection_elements(element_gen):  # elements are ARRAY or MAP
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark: unary_op_df(spark, ArrayGen(element_gen, max_length=8)).selectExpr(
+            'aggregate(a, 0L, (acc, x) -> acc + CAST(size(x) AS BIGINT), id -> id) '
+            'as total_size'))
+
+
+@disable_ansi_mode
+def test_array_aggregate_nested_filter_and_aggregate_over_struct_array_field():
+    charge_info_gen = StringGen(  # three tab-separated fields: digits, letters, tag
+        '[0-9]{1,2}\t[a-z]{2}\t(IGN_ZTC_CPA_CPC|MISS|-)', nullable=False
+    ).with_special_case('', weight=5.0)  # '' is what the inner z != '' filter removes
+    element_gen = StructGen([
+        ('product_id', IntegerGen(nullable=False, min_val=0, max_val=2, special_cases=[0, 1, 2])),
+        ('im_ad_res_field', StructGen([  # no nullable=False here; the SQL guards with IS NOT NULL
+            ('charge_info', ArrayGen(charge_info_gen, max_length=5, nullable=False))
+        ])),
+        ('unused_ids', ArrayGen(IntegerGen(nullable=False), max_length=5, nullable=False))  # never read by the SQL
+    ], nullable=False)
+    assert_gpu_and_cpu_are_equal_collect(  # inner aggregate() is nested inside the outer merge lambda
+        lambda spark: unary_op_df(spark, ArrayGen(element_gen, max_length=8)).selectExpr("""
+            aggregate(
+                filter(a, ad -> ad.product_id = 1 AND ad.im_ad_res_field IS NOT NULL),
+                0L,
+                (acc, ad) -> acc + coalesce(
+                    aggregate(
+                        filter(ad.im_ad_res_field.charge_info, z -> z != ''),
+                        0L,
+                        (acc2, z) -> acc2 + CASE WHEN (
+                            size(split(z, '\t', -1)) > 2
+                            AND split(z, '\t', -1)[2] IN ('-', 'IGN_ZTC_CPA_CPC')
+                        ) THEN CAST(split(z, '\t', -1)[0] AS BIGINT) ELSE 0L END,
+                        id -> id),
+                    0L),
+                id -> id
+            ) as charge_sum"""))
+
+
 @disable_ansi_mode
 def test_array_aggregate_non_zero_init():
     assert_gpu_and_cpu_are_equal_collect(
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index f4c763b9308..774b6134cd4 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ 
b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -3008,7 +3008,7 @@ object GpuOverrides extends Logging { Seq( ParamCheck("argument", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + - TypeSig.BINARY + TypeSig.STRUCT), + TypeSig.BINARY + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), TypeSig.ARRAY.nested(TypeSig.all)), ParamCheck("zero", TypeSig.commonCudfTypes + TypeSig.DECIMAL_128,