diff --git a/runtime/parser/declaration_test.go b/runtime/parser/declaration_test.go index f50618c322..1107d6cf5a 100644 --- a/runtime/parser/declaration_test.go +++ b/runtime/parser/declaration_test.go @@ -1246,13 +1246,7 @@ func TestParseFunctionDeclaration(t *testing.T) { t.Parallel() - result, errs := ParseDeclarations( - nil, - []byte("fun foo < > () {}"), - Config{ - TypeParametersEnabled: true, - }, - ) + result, errs := testParseDeclarations("fun foo < > () {}") require.Empty(t, errs) utils.AssertEqualWithDiff(t, @@ -1295,13 +1289,7 @@ func TestParseFunctionDeclaration(t *testing.T) { t.Parallel() - result, errs := ParseDeclarations( - nil, - []byte("fun foo < A > () {}"), - Config{ - TypeParametersEnabled: true, - }, - ) + result, errs := testParseDeclarations("fun foo < A > () {}") require.Empty(t, errs) utils.AssertEqualWithDiff(t, @@ -1351,13 +1339,7 @@ func TestParseFunctionDeclaration(t *testing.T) { t.Parallel() - result, errs := ParseDeclarations( - nil, - []byte("fun foo < A , B : C > () {}"), - Config{ - TypeParametersEnabled: true, - }, - ) + result, errs := testParseDeclarations("fun foo < A , B : C > () {}") require.Empty(t, errs) utils.AssertEqualWithDiff(t, @@ -1418,34 +1400,11 @@ func TestParseFunctionDeclaration(t *testing.T) { ) }) - t.Run("with type parameters, disabled", func(t *testing.T) { - - t.Parallel() - - _, errs := testParseDeclarations("fun foo() {}") - - utils.AssertEqualWithDiff(t, - []error{ - &SyntaxError{ - Message: "expected '(' as start of parameter list, got '<'", - Pos: ast.Position{Offset: 7, Line: 1, Column: 7}, - }, - }, - errs, - ) - }) - t.Run("missing type parameter list end, enabled", func(t *testing.T) { t.Parallel() - _, errs := ParseDeclarations( - nil, - []byte("fun foo < "), - Config{ - TypeParametersEnabled: true, - }, - ) + _, errs := testParseDeclarations("fun foo < ") utils.AssertEqualWithDiff(t, []error{ @@ -1462,13 +1421,7 @@ func TestParseFunctionDeclaration(t *testing.T) { t.Parallel() - _, errs := ParseDeclarations( - nil, - []byte("fun foo < A B > () { } "), - Config{ - TypeParametersEnabled: true, - }, - ) + _, errs := testParseDeclarations("fun foo < A B > () { } ") utils.AssertEqualWithDiff(t, []error{ diff --git a/runtime/parser/function.go b/runtime/parser/function.go index 5b22a5c6f8..28f694f1f0 100644 --- a/runtime/parser/function.go +++ b/runtime/parser/function.go @@ -313,12 +313,10 @@ func parseFunctionDeclaration( var typeParameterList *ast.TypeParameterList - if p.config.TypeParametersEnabled { - var err error - typeParameterList, err = parseTypeParameterList(p) - if err != nil { - return nil, err - } + var err error + typeParameterList, err = parseTypeParameterList(p) + if err != nil { + return nil, err } parameterList, returnTypeAnnotation, functionBlock, err := diff --git a/runtime/parser/parser.go b/runtime/parser/parser.go index 869e53c69c..edad4dfb3c 100644 --- a/runtime/parser/parser.go +++ b/runtime/parser/parser.go @@ -46,8 +46,6 @@ type Config struct { StaticModifierEnabled bool // NativeModifierEnabled determines if the native modifier is enabled NativeModifierEnabled bool - // TypeParametersEnabled determines if type parameters are enabled - TypeParametersEnabled bool } type parser struct { diff --git a/runtime/parser/statement.go b/runtime/parser/statement.go index 9db6281681..be1156f646 100644 --- a/runtime/parser/statement.go +++ b/runtime/parser/statement.go @@ -167,12 +167,10 @@ func parseFunctionDeclarationOrFunctionExpressionStatement(p *parser) (ast.State var typeParameterList *ast.TypeParameterList - if p.config.TypeParametersEnabled { - var err error - typeParameterList, err = parseTypeParameterList(p) - if err != nil { - return nil, err - } + var err error + typeParameterList, err = parseTypeParameterList(p) + if err != nil { + return nil, err } parameterList, returnTypeAnnotation, functionBlock, err := diff --git a/runtime/sema/check_composite_declaration.go b/runtime/sema/check_composite_declaration.go index c9a2d2c04d..1f4702411c 100644 --- a/runtime/sema/check_composite_declaration.go +++ b/runtime/sema/check_composite_declaration.go @@ -1758,7 +1758,11 @@ func (checker *Checker) defaultMembersAndOrigins( identifier := function.Identifier.Identifier - functionType := checker.functionType(function.ParameterList, function.ReturnTypeAnnotation) + functionType := checker.functionType( + function.TypeParameterList, + function.ParameterList, + function.ReturnTypeAnnotation, + ) argumentLabels := function.ParameterList.EffectiveArgumentLabels() @@ -2035,6 +2039,7 @@ func (checker *Checker) checkSpecialFunction( } checker.checkFunction( + nil, specialFunction.FunctionDeclaration.ParameterList, nil, functionType, diff --git a/runtime/sema/check_function.go b/runtime/sema/check_function.go index 8797b50f37..901c759447 100644 --- a/runtime/sema/check_function.go +++ b/runtime/sema/check_function.go @@ -21,6 +21,7 @@ package sema import ( "github.com/onflow/cadence/runtime/ast" "github.com/onflow/cadence/runtime/common" + "github.com/onflow/cadence/runtime/errors" ) func (checker *Checker) VisitFunctionDeclaration(declaration *ast.FunctionDeclaration) (_ struct{}) { @@ -80,7 +81,11 @@ func (checker *Checker) visitFunctionDeclaration( functionType := checker.Elaboration.FunctionDeclarationFunctionType(declaration) if functionType == nil { - functionType = checker.functionType(declaration.ParameterList, declaration.ReturnTypeAnnotation) + functionType = checker.functionType( + declaration.TypeParameterList, + declaration.ParameterList, + declaration.ReturnTypeAnnotation, + ) if options.declareFunction { checker.declareFunctionDeclaration(declaration, functionType) @@ -90,6 +95,7 @@ func (checker *Checker) visitFunctionDeclaration( checker.Elaboration.SetFunctionDeclarationFunctionType(declaration, functionType) checker.checkFunction( + declaration.TypeParameterList, declaration.ParameterList, declaration.ReturnTypeAnnotation, functionType, @@ -125,6 +131,7 @@ func (checker *Checker) declareFunctionDeclaration( } func (checker *Checker) checkFunction( + typeParameterList *ast.TypeParameterList, parameterList *ast.ParameterList, returnTypeAnnotation *ast.TypeAnnotation, functionType *FunctionType, @@ -133,6 +140,47 @@ func (checker *Checker) checkFunction( initializationInfo *InitializationInfo, checkResourceLoss bool, ) { + // If type parameters are given, + // resolve generic types in the function type + // to the type bounds of the type parameters. + // + // Type parameters must have type bounds, + // to at least determine resource-kindedness + // (A function cannot be written in a way that it supports + // either resources or non-resources.) + + typeParameters := functionType.TypeParameters + if len(typeParameters) > 0 { + + typeArguments := &TypeParameterTypeOrderedMap{} + + for typeParameterIndex, typeParameter := range typeParameters { + + typeBound := typeParameter.TypeBound + if typeBound == nil { + astTypeParameter := typeParameterList.TypeParameters[typeParameterIndex] + + checker.report(&MissingTypeParameterTypeBoundError{ + Name: typeParameter.Name, + Range: ast.NewUnmeteredRangeFromPositioned(astTypeParameter), + }) + continue + } + + typeArguments.Set(typeParameter, typeBound) + } + + resolvedType := functionType.Resolve(typeArguments) + + if resolvedType != nil { + var ok bool + functionType, ok = resolvedType.(*FunctionType) + if !ok { + panic(errors.NewUnreachableError()) + } + } + } + // check argument labels checker.checkArgumentLabels(parameterList) @@ -414,14 +462,24 @@ func (checker *Checker) declareBefore() { func (checker *Checker) VisitFunctionExpression(expression *ast.FunctionExpression) Type { + // TODO: add support in parser + var typeParameterList *ast.TypeParameterList + parameterList := expression.ParameterList + returnTypeAnnotation := expression.ReturnTypeAnnotation + // TODO: infer - functionType := checker.functionType(expression.ParameterList, expression.ReturnTypeAnnotation) + functionType := checker.functionType( + typeParameterList, + parameterList, + returnTypeAnnotation, + ) checker.Elaboration.SetFunctionExpressionFunctionType(expression, functionType) checker.checkFunction( - expression.ParameterList, - expression.ReturnTypeAnnotation, + typeParameterList, + parameterList, + returnTypeAnnotation, functionType, expression.FunctionBlock, true, diff --git a/runtime/sema/check_interface_declaration.go b/runtime/sema/check_interface_declaration.go index 2b2818fd69..d0bc7a7d87 100644 --- a/runtime/sema/check_interface_declaration.go +++ b/runtime/sema/check_interface_declaration.go @@ -205,6 +205,14 @@ func (checker *Checker) checkInterfaceFunctions( } } + if function.TypeParameterList != nil { + checker.report( + &InvalidTypeParameterizedInterfaceFunctionError{ + Range: ast.NewUnmeteredRangeFromPositioned(function.TypeParameterList), + }, + ) + } + checker.visitFunctionDeclaration( function, functionDeclarationOptions{ diff --git a/runtime/sema/check_transaction_declaration.go b/runtime/sema/check_transaction_declaration.go index aaca803d31..771fa36a79 100644 --- a/runtime/sema/check_transaction_declaration.go +++ b/runtime/sema/check_transaction_declaration.go @@ -182,6 +182,7 @@ func (checker *Checker) visitTransactionPrepareFunction( prepareFunctionType := transactionType.PrepareFunctionType() checker.checkFunction( + nil, prepareFunction.FunctionDeclaration.ParameterList, nil, prepareFunctionType, @@ -231,6 +232,7 @@ func (checker *Checker) visitTransactionExecuteFunction( executeFunctionType := transactionType.ExecuteFunctionType() checker.checkFunction( + nil, &ast.ParameterList{}, nil, executeFunctionType, diff --git a/runtime/sema/checker.go b/runtime/sema/checker.go index 6ed5cab3a6..8d14553d65 100644 --- a/runtime/sema/checker.go +++ b/runtime/sema/checker.go @@ -424,7 +424,11 @@ func (checker *Checker) checkTopLevelDeclarationValidity( } func (checker *Checker) declareGlobalFunctionDeclaration(declaration *ast.FunctionDeclaration) { - functionType := checker.functionType(declaration.ParameterList, declaration.ReturnTypeAnnotation) + functionType := checker.functionType( + declaration.TypeParameterList, + declaration.ParameterList, + declaration.ReturnTypeAnnotation, + ) checker.Elaboration.SetFunctionDeclarationFunctionType(declaration, functionType) checker.declareFunctionDeclaration(declaration, functionType) } @@ -1240,9 +1244,45 @@ func (checker *Checker) ConvertTypeAnnotation(typeAnnotation *ast.TypeAnnotation } func (checker *Checker) functionType( + typeParameterList *ast.TypeParameterList, parameterList *ast.ParameterList, returnTypeAnnotation *ast.TypeAnnotation, ) *FunctionType { + + var convertedTypeParameters []*TypeParameter + if typeParameterList != nil { + + checker.typeActivations.Enter() + defer checker.typeActivations.Leave(func(gauge common.MemoryGauge) ast.Position { + if returnTypeAnnotation != nil { + return returnTypeAnnotation.EndPosition(gauge) + } else { + return parameterList.EndPos + } + }) + + // All type parameters are converted at once, + // so type bounds may currently not refer to previous type parameters + + convertedTypeParameters = checker.typeParameters(typeParameterList) + + for typeParameterIndex, typeParameter := range typeParameterList.TypeParameters { + convertedTypeParameter := convertedTypeParameters[typeParameterIndex] + + genericType := &GenericType{ + TypeParameter: convertedTypeParameter, + } + + _, err := checker.typeActivations.declareType(typeDeclaration{ + identifier: typeParameter.Identifier, + ty: genericType, + declarationKind: common.DeclarationKindTypeParameter, + allowOuterScopeShadowing: false, + }) + checker.report(err) + + } + } convertedParameters := checker.parameters(parameterList) convertedReturnTypeAnnotation := VoidTypeAnnotation @@ -1252,11 +1292,40 @@ func (checker *Checker) functionType( } return &FunctionType{ + TypeParameters: convertedTypeParameters, Parameters: convertedParameters, ReturnTypeAnnotation: convertedReturnTypeAnnotation, } } +func (checker *Checker) typeParameters(typeParameterList *ast.TypeParameterList) []*TypeParameter { + + var typeParameters []*TypeParameter + + typeParameterCount := len(typeParameterList.TypeParameters) + if typeParameterCount > 0 { + typeParameters = make([]*TypeParameter, typeParameterCount) + + for i, typeParameter := range typeParameterList.TypeParameters { + + typeBoundAnnotation := typeParameter.TypeBound + var convertedTypeBound Type + if typeBoundAnnotation != nil { + convertedTypeBoundAnnotation := checker.ConvertTypeAnnotation(typeBoundAnnotation) + checker.checkTypeAnnotation(convertedTypeBoundAnnotation, typeBoundAnnotation) + convertedTypeBound = convertedTypeBoundAnnotation.Type + } + + typeParameters[i] = &TypeParameter{ + Name: typeParameter.Identifier.Identifier, + TypeBound: convertedTypeBound, + } + } + } + + return typeParameters +} + func (checker *Checker) parameters(parameterList *ast.ParameterList) []Parameter { var parameters []Parameter diff --git a/runtime/sema/errors.go b/runtime/sema/errors.go index 9012a60450..542b4d0aba 100644 --- a/runtime/sema/errors.go +++ b/runtime/sema/errors.go @@ -4046,3 +4046,41 @@ func (*AttachmentsNotEnabledError) IsUserError() {} func (e *AttachmentsNotEnabledError) Error() string { return "attachments are not enabled and cannot be used in this environment" } + +// MissingTypeParameterTypeBoundError + +type MissingTypeParameterTypeBoundError struct { + Name string + ast.Range +} + +var _ SemanticError = &MissingTypeParameterTypeBoundError{} +var _ errors.UserError = &MissingTypeParameterTypeBoundError{} + +func (*MissingTypeParameterTypeBoundError) isSemanticError() {} + +func (*MissingTypeParameterTypeBoundError) IsUserError() {} + +func (e *MissingTypeParameterTypeBoundError) Error() string { + return fmt.Sprintf( + "missing type bound for type parameter `%s`", + e.Name, + ) +} + +// InvalidTypeParameterizedInterfaceFunctionError + +type InvalidTypeParameterizedInterfaceFunctionError struct { + ast.Range +} + +var _ SemanticError = &InvalidTypeParameterizedInterfaceFunctionError{} +var _ errors.UserError = &InvalidTypeParameterizedInterfaceFunctionError{} + +func (*InvalidTypeParameterizedInterfaceFunctionError) isSemanticError() {} + +func (*InvalidTypeParameterizedInterfaceFunctionError) IsUserError() {} + +func (e *InvalidTypeParameterizedInterfaceFunctionError) Error() string { + return "invalid type parameters in interface function" +} diff --git a/runtime/sema/gen/main.go b/runtime/sema/gen/main.go index 50e9ee975f..10290bed5a 100644 --- a/runtime/sema/gen/main.go +++ b/runtime/sema/gen/main.go @@ -64,7 +64,6 @@ var parsedHeaderTemplate = template.Must(template.New("header").Parse(headerTemp var parserConfig = parser.Config{ StaticModifierEnabled: true, NativeModifierEnabled: true, - TypeParametersEnabled: true, } func initialUpper(s string) string { diff --git a/runtime/tests/checker/genericfunction_test.go b/runtime/tests/checker/genericfunction_test.go index 1282d8753a..084be0160b 100644 --- a/runtime/tests/checker/genericfunction_test.go +++ b/runtime/tests/checker/genericfunction_test.go @@ -50,7 +50,7 @@ func parseAndCheckWithTestValue(t *testing.T, code string, ty sema.Type) (*sema. ) } -func TestCheckGenericFunction(t *testing.T) { +func TestCheckGenericFunctionInvocation(t *testing.T) { t.Parallel() @@ -925,7 +925,7 @@ func TestCheckBorrowOfCapabilityWithoutTypeArgument(t *testing.T) { require.NoError(t, err) } -func TestCheckUnparameterizedTypeInstantiationE(t *testing.T) { +func TestCheckUnparameterizedTypeInstantiation(t *testing.T) { t.Parallel() @@ -939,3 +939,167 @@ func TestCheckUnparameterizedTypeInstantiationE(t *testing.T) { assert.IsType(t, &sema.UnparameterizedTypeInstantiationError{}, errs[0]) } + +func TestCheckGenericFunctionDeclaration(t *testing.T) { + + t.Parallel() + + t.Run("global, struct", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun head(_ items: [T]): T? { + if items.length < 1 { + return nil + } + return items[0] + } + + let x: Int? = head([1, 2, 3]) + `) + + require.NoError(t, err) + }) + + t.Run("global, resource", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun head(_ items: @[T]): @T? { + if items.length < 1 { + destroy items + return nil + } + let item <-items.remove(at: 0) + destroy items + return <-item + } + + resource R {} + + let x: @R? <- head(<-[<-create R()]) + `) + + require.NoError(t, err) + }) + + t.Run("missing type parameter bound", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun test() {} + `) + + errs := RequireCheckerErrors(t, err, 1) + + require.IsType(t, &sema.MissingTypeParameterTypeBoundError{}, errs[0]) + }) + + t.Run("too many type arguments", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun test() {} + + let x = test() + `) + + errs := RequireCheckerErrors(t, err, 1) + + require.IsType(t, &sema.InvalidTypeArgumentCountError{}, errs[0]) + }) + + t.Run("type parameter usage in function body", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun test(): Type { + return Type() + } + `) + + errs := RequireCheckerErrors(t, err, 2) + + assert.IsType(t, &sema.NotDeclaredError{}, errs[0]) + assert.IsType(t, &sema.TypeParameterTypeInferenceError{}, errs[1]) + }) + + t.Run("type parameter usage in following type parameter", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun test(_ u: U): U { return u } + `) + + errs := RequireCheckerErrors(t, err, 1) + + assert.IsType(t, &sema.NotDeclaredError{}, errs[0]) + }) + + t.Run("composite, struct", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + struct S { + fun head(_ items: [T]): T? { + if items.length < 1 { + return nil + } + return items[0] + } + } + + let x: Int? = S().head([1, 2, 3]) + `) + + require.NoError(t, err) + }) + + t.Run("composite, resource", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + struct S { + fun head(_ items: @[T]): @T? { + if items.length < 1 { + destroy items + return nil + } + let item <-items.remove(at: 0) + destroy items + return <-item + } + } + + resource R {} + + let x: @R? <- S().head(<-[<-create R()]) + `) + + require.NoError(t, err) + }) + + t.Run("interface", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + struct interface SI { + fun foo() + } + `) + + errs := RequireCheckerErrors(t, err, 1) + + assert.IsType(t, &sema.InvalidTypeParameterizedInterfaceFunctionError{}, errs[0]) + }) + +} diff --git a/runtime/tests/interpreter/generic_test.go b/runtime/tests/interpreter/generic_test.go new file mode 100644 index 0000000000..f3241758b4 --- /dev/null +++ b/runtime/tests/interpreter/generic_test.go @@ -0,0 +1,91 @@ +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright Dapper Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package interpreter_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/onflow/cadence/runtime/interpreter" + "github.com/onflow/cadence/runtime/tests/utils" +) + +func TestInterpretGenericFunctionDeclaration(t *testing.T) { + + t.Parallel() + + t.Run("global", func(t *testing.T) { + + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + fun head(_ items: [T]): T? { + if items.length < 1 { + return nil + } + return items[0] + } + + fun test(): Int { + return head([1, 2, 3])! + } + `) + + result, err := inter.Invoke("test") + require.NoError(t, err) + + utils.RequireValuesEqual( + t, + inter, + interpreter.NewIntValueFromInt64(nil, 1), + result, + ) + }) + + t.Run("composite", func(t *testing.T) { + + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + struct S { + fun head(_ items: [T]): T? { + if items.length < 1 { + return nil + } + return items[0] + } + } + + fun test(): Int { + return S().head([1, 2, 3])! + } + `) + + result, err := inter.Invoke("test") + require.NoError(t, err) + + utils.RequireValuesEqual( + t, + inter, + interpreter.NewIntValueFromInt64(nil, 1), + result, + ) + }) +}