diff --git a/bbq/vm/value_array.go b/bbq/vm/value_array.go index 4a9f4a0a3..ead292735 100644 --- a/bbq/vm/value_array.go +++ b/bbq/vm/value_array.go @@ -93,6 +93,18 @@ func init() { interpreter.NativeArrayMapFunction, ), ) + + registerBuiltinTypeBoundFunction( + typeQualifier, + NewNativeFunctionValueWithDerivedType( + sema.ArrayTypeReduceFunctionName, + func(receiver Value, context interpreter.ValueStaticTypeContext) *sema.FunctionType { + elementType := arrayElementTypeFromValue(receiver, context) + return sema.ArrayReduceFunctionType(elementType) + }, + interpreter.NativeArrayReduceFunction, + ), + ) } // Functions available only for variable-sized arrays. diff --git a/interpreter/array_test.go b/interpreter/array_test.go index 758721b85..01c67013e 100644 --- a/interpreter/array_test.go +++ b/interpreter/array_test.go @@ -24,6 +24,7 @@ import ( "github.com/stretchr/testify/require" "github.com/onflow/cadence/interpreter" + . "github.com/onflow/cadence/test_utils/interpreter_utils" ) func TestInterpretArrayFunctionEntitlements(t *testing.T) { @@ -155,3 +156,164 @@ func TestCheckArrayReferenceTypeInferenceWithDowncasting(t *testing.T) { var forceCastTypeMismatchError *interpreter.ForceCastTypeMismatchError require.ErrorAs(t, err, &forceCastTypeMismatchError) } + +func TestInterpretArrayReduce(t *testing.T) { + t.Parallel() + + t.Run("with variable sized array - sum", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndPrepare(t, ` + let xs = [1, 2, 3, 4, 5] + + let sum = + fun (acc: Int, x: Int): Int { + return acc + x + } + + fun reduce(): Int { + return xs.reduce(initial: 0, sum) + } + `) + + val, err := inter.Invoke("reduce") + require.NoError(t, err) + + AssertValuesEqual( + t, + inter, + interpreter.NewUnmeteredIntValueFromInt64(15), + val, + ) + }) + + t.Run("with variable sized array - product", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndPrepare(t, ` + let xs = [1, 2, 3, 4] + + let product = + fun (acc: Int, x: Int): Int { + return acc * x + } + + fun reduce(): Int { + return xs.reduce(initial: 1, product) + } + `) + + val, err := inter.Invoke("reduce") + require.NoError(t, err) + + AssertValuesEqual( + t, + inter, + interpreter.NewUnmeteredIntValueFromInt64(24), + val, + ) + }) + + t.Run("with empty array", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndPrepare(t, ` + let xs: [Int] = [] + + let sum = + fun (acc: Int, x: Int): Int { + return acc + x + } + + fun reduce(): Int { + return xs.reduce(initial: 42, sum) + } + `) + + val, err := inter.Invoke("reduce") + require.NoError(t, err) + + AssertValuesEqual( + t, + inter, + interpreter.NewUnmeteredIntValueFromInt64(42), + val, + ) + }) + + t.Run("with fixed sized array", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndPrepare(t, ` + let xs: [Int; 4] = [10, 20, 30, 40] + + let sum = + fun (acc: Int, x: Int): Int { + return acc + x + } + + fun reduce(): Int { + return xs.reduce(initial: 0, sum) + } + `) + + val, err := inter.Invoke("reduce") + require.NoError(t, err) + + AssertValuesEqual( + t, + inter, + interpreter.NewUnmeteredIntValueFromInt64(100), + val, + ) + }) + + t.Run("with type conversion", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndPrepare(t, ` + let xs = [1, 2, 3] + + let concat = + fun (acc: String, x: Int): String { + return acc.concat(x.toString()) + } + + fun reduce(): String { + return xs.reduce(initial: "", concat) + } + `) + + val, err := inter.Invoke("reduce") + require.NoError(t, err) + + AssertValuesEqual( + t, + inter, + interpreter.NewUnmeteredStringValue("123"), + val, + ) + }) + + t.Run("mutation during reduce", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndPrepare(t, ` + let xs = [1, 2, 3] + let xRef = &xs as auth(Mutate) &[Int] + let sum = + fun (acc: Int, x: Int): Int { + xRef.remove(at: 0) + return acc + x + } + + fun reduce(): Int { + return xRef.reduce(initial: 0, sum) + } + `) + + _, err := inter.Invoke("reduce") + var containerMutationError *interpreter.ContainerMutatedDuringIterationError + require.ErrorAs(t, err, &containerMutationError) + }) +} diff --git a/interpreter/value_array.go b/interpreter/value_array.go index cb5bba785..7470f88df 100644 --- a/interpreter/value_array.go +++ b/interpreter/value_array.go @@ -988,6 +988,16 @@ func (v *ArrayValue) GetMethod(context MemberAccessibleContext, name string) Fun NativeArrayMapFunction, ) + case sema.ArrayTypeReduceFunctionName: + return NewBoundHostFunctionValue( + context, + v, + sema.ArrayReduceFunctionType( + v.SemaType(context).ElementType(false), + ), + NativeArrayReduceFunction, + ) + case sema.ArrayTypeToVariableSizedFunctionName: return NewBoundHostFunctionValue( context, @@ -1696,6 +1706,48 @@ func (v *ArrayValue) Map( ) } +func (v *ArrayValue) Reduce( + context InvocationContext, + initial Value, + reducer FunctionValue, +) Value { + + elementType := v.SemaType(context).ElementType(false) + + reducerFunctionType := reducer.FunctionType(context) + parameterTypes := reducerFunctionType.ParameterTypes() + returnType := reducerFunctionType.ReturnTypeAnnotation.Type + + accumulator := initial + + v.Iterate( + context, + func(element Value) (resume bool) { + + // Meter computation for iterating the array. + common.UseComputation( + context, + common.LoopComputationUsage, + ) + + accumulator = invokeFunctionValue( + context, + reducer, + []Value{accumulator, element}, + []sema.Type{returnType, elementType}, + parameterTypes, + returnType, + nil, + ) + + return true + }, + false, + ) + + return accumulator +} + func (v *ArrayValue) ForEach( context IterableValueForeachContext, _ sema.Type, @@ -2049,6 +2101,21 @@ var NativeArrayMapFunction = NativeFunction( }, ) +var NativeArrayReduceFunction = NativeFunction( + func( + context NativeFunctionContext, + _ TypeArgumentsIterator, + receiver Value, + args []Value, + ) Value { + thisArray := AssertValueOfType[*ArrayValue](receiver) + initial := args[0] + funcValue := AssertValueOfType[FunctionValue](args[1]) + + return thisArray.Reduce(context, initial, funcValue) + }, +) + var NativeArrayToVariableSizedFunction = NativeFunction( func( context NativeFunctionContext, diff --git a/sema/arrays_dictionaries_test.go b/sema/arrays_dictionaries_test.go index 8ae10c8af..1ef528ff9 100644 --- a/sema/arrays_dictionaries_test.go +++ b/sema/arrays_dictionaries_test.go @@ -1329,6 +1329,155 @@ func TestCheckResourceArrayMapInvalid(t *testing.T) { assert.IsType(t, &sema.InvalidResourceArrayMemberError{}, errs[0]) } +func TestCheckArrayReduce(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun test() { + let x = [1, 2, 3] + let sum = + fun (acc: Int, x: Int): Int { + return acc + x + } + + let y: Int = x.reduce(initial: 0, sum) + } + + fun testFixedSize() { + let x : [Int; 5] = [1, 2, 3, 21, 30] + let product = + fun (acc: Int, x: Int): Int { + return acc * x + } + + let y: Int = x.reduce(initial: 1, product) + } + + struct S { + var index: Int + + init(index: Int) { + self.index = index + } + } + + fun testStructArrayReduce() { + let structs = [S(index: 1), S(index: 2), S(index: 3)] + let ref = &structs as &[S] + + let reducer = + fun (acc: Int, x: S): Int { + x.index = x.index + 1 + return acc + x.index + } + + let y: Int = ref.reduce(initial: 0, reducer) + } + `) + + require.NoError(t, err) +} + +func TestCheckArrayReduceInvalidArgs(t *testing.T) { + + t.Parallel() + + testInvalidArgs := func(code string, expectedErrors []sema.SemanticError) { + _, err := ParseAndCheck(t, code) + + errs := RequireCheckerErrors(t, err, len(expectedErrors)) + + for i, e := range expectedErrors { + assert.IsType(t, e, errs[i]) + } + } + + testInvalidArgs(` + fun test() { + let x = [1, 2, 3] + let y = x.reduce(initial: 0, 100) + } + `, + []sema.SemanticError{ + &sema.TypeMismatchError{}, + }, + ) + + testInvalidArgs(` + fun test() { + let x = [1, 2, 3] + let sumInt16 = + fun (acc: Int16, x: Int16): Int16 { + return acc + x + } + + let y: Int = x.reduce(initial: 0, sumInt16) + } + `, + []sema.SemanticError{ + &sema.TypeMismatchError{}, + }, + ) + + testInvalidArgs(` + fun test() { + let x = [1, 2, 3] + let sum = + fun (acc: Int, x: Int): Int { + return acc + x + } + let y = x.reduce(initial: "0", sum) + } + `, + []sema.SemanticError{ + &sema.TypeMismatchError{}, + }, + ) + + testInvalidArgs(` + fun test() { + let x = [[1], [2]] + let reducer = + fun (acc: [Int], inner: auth(Mutate) &[Int]): [Int] { + inner.append(1) + return acc + } + let y = x.reduce(initial: [], reducer) + } + `, + []sema.SemanticError{ + &sema.TypeAnnotationRequiredError{}, + }, + ) +} + +func TestCheckResourceArrayReduceInvalid(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + resource X {} + + fun test(): Int { + let xs <- [<-create X()] + let reducer = + fun (acc: Int, x: @X): Int { + destroy x + return acc + 1 + } + + let count: Int = xs.reduce(initial: 0, reducer) + destroy xs + return count + } + `) + + errs := RequireCheckerErrors(t, err, 1) + + assert.IsType(t, &sema.InvalidResourceArrayMemberError{}, errs[0]) +} + func TestCheckArrayContains(t *testing.T) { t.Parallel() diff --git a/sema/type.go b/sema/type.go index cba4330db..d14611afa 100644 --- a/sema/type.go +++ b/sema/type.go @@ -2260,6 +2260,12 @@ const arrayTypeMapFunctionDocString = ` Returns a new array whose elements are produced by applying the mapper function on each element of the original array. ` +const ArrayTypeReduceFunctionName = "reduce" + +const arrayTypeReduceFunctionDocString = ` +Reduces the array to a single value by calling the reducer function for each element, passing the accumulator and the element. +` + func getArrayMembers(arrayType ArrayType) map[string]MemberResolver { members := map[string]MemberResolver{ @@ -2453,6 +2459,35 @@ func getArrayMembers(arrayType ArrayType) map[string]MemberResolver { ) }, }, + ArrayTypeReduceFunctionName: { + Kind: common.DeclarationKindFunction, + Resolve: func( + memoryGauge common.MemoryGauge, + identifier string, + targetRange ast.HasPosition, + report func(error), + ) *Member { + elementType := arrayType.ElementType(false) + + if elementType.IsResourceType() { + report( + &InvalidResourceArrayMemberError{ + Name: identifier, + DeclarationKind: common.DeclarationKindFunction, + Range: ast.NewRangeFromPositioned(memoryGauge, targetRange), + }, + ) + } + + return NewPublicFunctionMember( + memoryGauge, + arrayType, + identifier, + ArrayReduceFunctionType(elementType), + arrayTypeReduceFunctionDocString, + ) + }, + }, } // TODO: maybe still return members but report a helpful error? @@ -3019,6 +3054,51 @@ func ArrayMapFunctionType(memoryGauge common.MemoryGauge, arrayType ArrayType) * } } +func ArrayReduceFunctionType(elementType Type) *FunctionType { + // fun reduce(initial: U, _ f: fun (U, T): U): U + + typeParameter := &TypeParameter{ + Name: "U", + } + + typeU := &GenericType{ + TypeParameter: typeParameter, + } + + // reducerFuncType: (U, T) -> U + reducerFuncType := &FunctionType{ + Parameters: []Parameter{ + { + Identifier: "accumulator", + TypeAnnotation: NewTypeAnnotation(typeU), + }, + { + Identifier: "element", + TypeAnnotation: NewTypeAnnotation(elementType), + }, + }, + ReturnTypeAnnotation: NewTypeAnnotation(typeU), + } + + return &FunctionType{ + TypeParameters: []*TypeParameter{ + typeParameter, + }, + Parameters: []Parameter{ + { + Identifier: "initial", + TypeAnnotation: NewTypeAnnotation(typeU), + }, + { + Label: ArgumentLabelNotRequired, + Identifier: "f", + TypeAnnotation: NewTypeAnnotation(reducerFuncType), + }, + }, + ReturnTypeAnnotation: NewTypeAnnotation(typeU), + } +} + // VariableSizedType is a variable sized array type type VariableSizedType struct { Type Type