From a70839748cf3d2a19a44519253ffbe03f567e3c1 Mon Sep 17 00:00:00 2001 From: Tanner Stirrat Date: Mon, 10 Feb 2025 08:37:37 -0700 Subject: [PATCH] Port expiration compiler changes into composableschemadsl --- pkg/composableschemadsl/compiler/compiler.go | 23 ++++- .../compiler/compiler_test.go | 57 +++++++++++++ .../compiler/translator.go | 85 ++++++++++++++----- .../generator/generator.go | 45 ++++++++-- .../generator/generator_impl.go | 17 ++-- .../generator/generator_test.go | 27 +++++- .../lexer/flaggablelexer.go | 14 +-- .../lexer/flaggablelexer_test.go | 2 +- pkg/composableschemadsl/parser/parser.go | 21 +++-- pkg/composableschemadsl/parser/parser_impl.go | 16 ++-- pkg/composableschemadsl/parser/parser_test.go | 4 + .../parser/tests/duplicate_use_statement.zed | 7 ++ .../duplicate_use_statement.zed.expected | 71 ++++++++++++++++ .../parser/tests/useafterdef.zed.expected | 17 ++-- pkg/schemadsl/compiler/compiler.go | 2 +- pkg/schemadsl/compiler/compiler_test.go | 13 +++ pkg/schemadsl/compiler/translator.go | 51 +++++++---- pkg/schemadsl/lexer/flaggablelexer.go | 14 +-- pkg/schemadsl/lexer/flaggablelexer_test.go | 2 +- pkg/schemadsl/parser/parser.go | 21 +++-- pkg/schemadsl/parser/parser_impl.go | 16 ++-- .../parser/tests/duplicate_use_statement.zed | 7 ++ .../duplicate_use_statement.zed.expected | 71 ++++++++++++++++ .../parser/tests/useafterdef.zed.expected | 17 ++-- 24 files changed, 509 insertions(+), 111 deletions(-) create mode 100644 pkg/composableschemadsl/parser/tests/duplicate_use_statement.zed create mode 100644 pkg/composableschemadsl/parser/tests/duplicate_use_statement.zed.expected create mode 100644 pkg/schemadsl/parser/tests/duplicate_use_statement.zed create mode 100644 pkg/schemadsl/parser/tests/duplicate_use_statement.zed.expected diff --git a/pkg/composableschemadsl/compiler/compiler.go b/pkg/composableschemadsl/compiler/compiler.go index 81b4fcfdd3..9c0e10ee88 100644 --- a/pkg/composableschemadsl/compiler/compiler.go +++ b/pkg/composableschemadsl/compiler/compiler.go @@ -5,6 +5,7 @@ import ( "fmt" "google.golang.org/protobuf/proto" + "k8s.io/utils/strings/slices" "github.com/authzed/spicedb/pkg/composableschemadsl/dslshape" "github.com/authzed/spicedb/pkg/composableschemadsl/input" @@ -53,6 +54,7 @@ func (cs CompiledSchema) SourcePositionToRunePosition(source input.Source, posit type config struct { skipValidation bool objectTypePrefix *string + allowedFlags []string // In an import context, this is the folder containing // the importing schema (as opposed to imported schemas) sourceFolder string @@ -76,6 +78,16 @@ func AllowUnprefixedObjectType() ObjectPrefixOption { return func(cfg *config) { cfg.objectTypePrefix = new(string) } } +const expirationFlag = "expiration" + +func DisallowExpirationFlag() Option { + return func(cfg *config) { + cfg.allowedFlags = slices.Filter([]string{}, cfg.allowedFlags, func(s string) bool { + return s != expirationFlag + }) + } +} + // Config that supplies the root source folder for compilation. Required // for relative import syntax to work properly. func SourceFolder(sourceFolder string) Option { @@ -88,7 +100,13 @@ type ObjectPrefixOption func(*config) // Compile compilers the input schema into a set of namespace definition protos. func Compile(schema InputSchema, prefix ObjectPrefixOption, opts ...Option) (*CompiledSchema, error) { - cfg := &config{} + cfg := &config{ + allowedFlags: make([]string, 0, 1), + } + + // Enable `expiration` flag by default. + cfg.allowedFlags = append(cfg.allowedFlags, expirationFlag) + prefix(cfg) // required option for _, fn := range opts { @@ -111,11 +129,12 @@ func Compile(schema InputSchema, prefix ObjectPrefixOption, opts ...Option) (*Co return nil, err } - compiled, err := translate(translationContext{ + compiled, err := translate(&translationContext{ objectTypePrefix: cfg.objectTypePrefix, mapper: mapper, schemaString: schema.SchemaString, skipValidate: cfg.skipValidation, + allowedFlags: cfg.allowedFlags, existingNames: mapz.NewSet[string](), compiledPartials: make(map[string][]*core.Relation), unresolvedPartials: mapz.NewMultiMap[string, *dslNode](), diff --git a/pkg/composableschemadsl/compiler/compiler_test.go b/pkg/composableschemadsl/compiler/compiler_test.go index 79f65fd896..0624344fb8 100644 --- a/pkg/composableschemadsl/compiler/compiler_test.go +++ b/pkg/composableschemadsl/compiler/compiler_test.go @@ -1206,6 +1206,63 @@ func TestCompile(t *testing.T) { ), }, }, + { + "relation with expiration trait", + withTenantPrefix, + `use expiration + + definition simple { + relation viewer: user with expiration + }`, + "", + []SchemaDefinition{ + namespace.Namespace("sometenant/simple", + namespace.MustRelation("viewer", nil, + namespace.AllowedRelationWithExpiration("sometenant/user", "..."), + ), + ), + }, + }, + { + "duplicate use pragmas", + withTenantPrefix, + ` + use expiration + use expiration + + definition simple { + relation viewer: user with expiration + }`, + `found duplicate use flag`, + []SchemaDefinition{}, + }, + { + "expiration use without use expiration", + withTenantPrefix, + ` + definition simple { + relation viewer: user with expiration + }`, + `expiration flag is not enabled`, + []SchemaDefinition{}, + }, + { + "relation with expiration trait and caveat", + withTenantPrefix, + `use expiration + + definition simple { + relation viewer: user with somecaveat and expiration + }`, + "", + []SchemaDefinition{ + namespace.Namespace("sometenant/simple", + namespace.MustRelation("viewer", nil, + namespace.AllowedRelationWithCaveatAndExpiration("sometenant/user", "...", namespace.AllowedCaveat("sometenant/somecaveat")), + ), + ), + }, + }, } for _, test := range tests { diff --git a/pkg/composableschemadsl/compiler/translator.go b/pkg/composableschemadsl/compiler/translator.go index 4ec48d888d..3d54598c57 100644 --- a/pkg/composableschemadsl/compiler/translator.go +++ b/pkg/composableschemadsl/compiler/translator.go @@ -5,6 +5,7 @@ import ( "container/list" "fmt" "path/filepath" + "slices" "strings" "github.com/ccoveille/go-safecast" @@ -25,6 +26,8 @@ type translationContext struct { mapper input.PositionMapper schemaString string skipValidate bool + allowedFlags []string + enabledFlags []string existingNames *mapz.Set[string] // The mapping of partial name -> relations represented by the partial compiledPartials map[string][]*core.Relation @@ -33,7 +36,7 @@ type translationContext struct { unresolvedPartials *mapz.MultiMap[string, *dslNode] } -func (tctx translationContext) prefixedPath(definitionName string) (string, error) { +func (tctx *translationContext) prefixedPath(definitionName string) (string, error) { var prefix, name string if err := stringz.SplitInto(definitionName, "/", &prefix, &name); err != nil { if tctx.objectTypePrefix == nil { @@ -52,7 +55,7 @@ func (tctx translationContext) prefixedPath(definitionName string) (string, erro const Ellipsis = "..." -func translate(tctx translationContext, root *dslNode) (*CompiledSchema, error) { +func translate(tctx *translationContext, root *dslNode) (*CompiledSchema, error) { orderedDefinitions := make([]SchemaDefinition, 0, len(root.GetChildren())) var objectDefinitions []*core.NamespaceDefinition var caveatDefinitions []*core.CaveatDefinition @@ -71,6 +74,12 @@ func translate(tctx translationContext, root *dslNode) (*CompiledSchema, error) for _, topLevelNode := range root.GetChildren() { switch topLevelNode.GetType() { + case dslshape.NodeTypeUseFlag: + err := translateUseFlag(tctx, topLevelNode) + if err != nil { + return nil, err + } + case dslshape.NodeTypeCaveatDefinition: log.Trace().Msg("adding caveat definition") // TODO: Maybe refactor these in terms of a generic function? @@ -115,7 +124,7 @@ func translate(tctx translationContext, root *dslNode) (*CompiledSchema, error) }, nil } -func translateCaveatDefinition(tctx translationContext, defNode *dslNode) (*core.CaveatDefinition, error) { +func translateCaveatDefinition(tctx *translationContext, defNode *dslNode) (*core.CaveatDefinition, error) { definitionName, err := defNode.GetString(dslshape.NodeCaveatDefinitionPredicateName) if err != nil { return nil, defNode.WithSourceErrorf(definitionName, "invalid definition name: %w", err) @@ -197,7 +206,7 @@ func translateCaveatDefinition(tctx translationContext, defNode *dslNode) (*core return def, nil } -func translateCaveatTypeReference(tctx translationContext, typeRefNode *dslNode) (*caveattypes.VariableType, error) { +func translateCaveatTypeReference(tctx *translationContext, typeRefNode *dslNode) (*caveattypes.VariableType, error) { typeName, err := typeRefNode.GetString(dslshape.NodeCaveatTypeReferencePredicateType) if err != nil { return nil, typeRefNode.WithSourceErrorf(typeName, "invalid type name: %w", err) @@ -221,7 +230,7 @@ func translateCaveatTypeReference(tctx translationContext, typeRefNode *dslNode) return constructedType, nil } -func translateObjectDefinition(tctx translationContext, defNode *dslNode) (*core.NamespaceDefinition, error) { +func translateObjectDefinition(tctx *translationContext, defNode *dslNode) (*core.NamespaceDefinition, error) { definitionName, err := defNode.GetString(dslshape.NodeDefinitionPredicateName) if err != nil { return nil, defNode.WithSourceErrorf(definitionName, "invalid definition name: %w", err) @@ -269,7 +278,7 @@ func translateObjectDefinition(tctx translationContext, defNode *dslNode) (*core // A value of true treats that as an error state, since all partials should be resolved when translating definitions, // where the false value returns the name of the partial for collection for future processing // when translating partials. -func translateRelationsAndPermissions(tctx translationContext, astNode *dslNode, errorOnMissingReference bool) ([]*core.Relation, string, error) { +func translateRelationsAndPermissions(tctx *translationContext, astNode *dslNode, errorOnMissingReference bool) ([]*core.Relation, string, error) { relationsAndPermissions := []*core.Relation{} for _, definitionChildNode := range astNode.GetChildren() { if definitionChildNode.GetType() == dslshape.NodeTypeComment { @@ -345,7 +354,7 @@ func normalizeComment(value string) string { return strings.Join(lines, "\n") } -func translateRelationOrPermission(tctx translationContext, relOrPermNode *dslNode) (*core.Relation, error) { +func translateRelationOrPermission(tctx *translationContext, relOrPermNode *dslNode) (*core.Relation, error) { switch relOrPermNode.GetType() { case dslshape.NodeTypeRelation: rel, err := translateRelation(tctx, relOrPermNode) @@ -370,7 +379,7 @@ func translateRelationOrPermission(tctx translationContext, relOrPermNode *dslNo } } -func translateRelation(tctx translationContext, relationNode *dslNode) (*core.Relation, error) { +func translateRelation(tctx *translationContext, relationNode *dslNode) (*core.Relation, error) { relationName, err := relationNode.GetString(dslshape.NodePredicateName) if err != nil { return nil, relationNode.Errorf("invalid relation name: %w", err) @@ -400,7 +409,7 @@ func translateRelation(tctx translationContext, relationNode *dslNode) (*core.Re return relation, nil } -func translatePermission(tctx translationContext, permissionNode *dslNode) (*core.Relation, error) { +func translatePermission(tctx *translationContext, permissionNode *dslNode) (*core.Relation, error) { permissionName, err := permissionNode.GetString(dslshape.NodePredicateName) if err != nil { return nil, permissionNode.Errorf("invalid permission name: %w", err) @@ -430,7 +439,7 @@ func translatePermission(tctx translationContext, permissionNode *dslNode) (*cor return permission, nil } -func translateBinary(tctx translationContext, expressionNode *dslNode) (*core.SetOperation_Child, *core.SetOperation_Child, error) { +func translateBinary(tctx *translationContext, expressionNode *dslNode) (*core.SetOperation_Child, *core.SetOperation_Child, error) { leftChild, err := expressionNode.Lookup(dslshape.NodeExpressionPredicateLeftExpr) if err != nil { return nil, nil, err @@ -454,7 +463,7 @@ func translateBinary(tctx translationContext, expressionNode *dslNode) (*core.Se return leftOperation, rightOperation, nil } -func translateExpression(tctx translationContext, expressionNode *dslNode) (*core.UsersetRewrite, error) { +func translateExpression(tctx *translationContext, expressionNode *dslNode) (*core.UsersetRewrite, error) { translated, err := translateExpressionDirect(tctx, expressionNode) if err != nil { return translated, err @@ -482,7 +491,7 @@ func collapseOps(op *core.SetOperation_Child, handler func(rewrite *core.Userset return collapsed } -func translateExpressionDirect(tctx translationContext, expressionNode *dslNode) (*core.UsersetRewrite, error) { +func translateExpressionDirect(tctx *translationContext, expressionNode *dslNode) (*core.UsersetRewrite, error) { // For union and intersection, we collapse a tree of binary operations into a flat list containing child // operations of the *same* type. translate := func( @@ -528,7 +537,7 @@ func translateExpressionDirect(tctx translationContext, expressionNode *dslNode) } } -func translateExpressionOperation(tctx translationContext, expressionOpNode *dslNode) (*core.SetOperation_Child, error) { +func translateExpressionOperation(tctx *translationContext, expressionOpNode *dslNode) (*core.SetOperation_Child, error) { translated, err := translateExpressionOperationDirect(tctx, expressionOpNode) if err != nil { return translated, err @@ -538,7 +547,7 @@ func translateExpressionOperation(tctx translationContext, expressionOpNode *dsl return translated, nil } -func translateExpressionOperationDirect(tctx translationContext, expressionOpNode *dslNode) (*core.SetOperation_Child, error) { +func translateExpressionOperationDirect(tctx *translationContext, expressionOpNode *dslNode) (*core.SetOperation_Child, error) { switch expressionOpNode.GetType() { case dslshape.NodeTypeIdentifier: referencedRelationName, err := expressionOpNode.GetString(dslshape.NodeIdentiferPredicateValue) @@ -605,7 +614,7 @@ func translateExpressionOperationDirect(tctx translationContext, expressionOpNod } } -func translateAllowedRelations(tctx translationContext, typeRefNode *dslNode) ([]*core.AllowedRelation, error) { +func translateAllowedRelations(tctx *translationContext, typeRefNode *dslNode) ([]*core.AllowedRelation, error) { switch typeRefNode.GetType() { case dslshape.NodeTypeTypeReference: references := []*core.AllowedRelation{} @@ -631,7 +640,7 @@ func translateAllowedRelations(tctx translationContext, typeRefNode *dslNode) ([ } } -func translateSpecificTypeReference(tctx translationContext, typeRefNode *dslNode) (*core.AllowedRelation, error) { +func translateSpecificTypeReference(tctx *translationContext, typeRefNode *dslNode) (*core.AllowedRelation, error) { typePath, err := typeRefNode.GetString(dslshape.NodeSpecificReferencePredicateType) if err != nil { return nil, typeRefNode.Errorf("invalid type name: %w", err) @@ -680,11 +689,34 @@ func translateSpecificTypeReference(tctx translationContext, typeRefNode *dslNod }, } + // Add the caveat(s), if any. err = addWithCaveats(tctx, typeRefNode, ref) if err != nil { return nil, typeRefNode.Errorf("invalid caveat: %w", err) } + // Add the expiration trait, if any. + if traitNode, err := typeRefNode.Lookup(dslshape.NodeSpecificReferencePredicateTrait); err == nil { + traitName, err := traitNode.GetString(dslshape.NodeTraitPredicateTrait) + if err != nil { + return nil, typeRefNode.Errorf("invalid trait: %w", err) + } + + if traitName != "expiration" { + return nil, typeRefNode.Errorf("invalid trait: %s", traitName) + } + + if !slices.Contains(tctx.allowedFlags, "expiration") { + return nil, typeRefNode.Errorf("expiration trait is not allowed") + } + + if !slices.Contains(tctx.enabledFlags, "expiration") { + return nil, typeRefNode.Errorf("expiration flag is not enabled; add `use expiration` to top of file") + } + + ref.RequiredExpiration = &core.ExpirationTrait{} + } + if !tctx.skipValidate { if err := ref.Validate(); err != nil { return nil, typeRefNode.Errorf("invalid type relation: %w", err) @@ -695,7 +727,7 @@ func translateSpecificTypeReference(tctx translationContext, typeRefNode *dslNod return ref, nil } -func addWithCaveats(tctx translationContext, typeRefNode *dslNode, ref *core.AllowedRelation) error { +func addWithCaveats(tctx *translationContext, typeRefNode *dslNode, ref *core.AllowedRelation) error { caveats := typeRefNode.List(dslshape.NodeSpecificReferencePredicateCaveat) if len(caveats) == 0 { return nil @@ -809,7 +841,7 @@ func translateImports(itctx importResolutionContext, root *dslNode) error { return nil } -func collectPartials(tctx translationContext, rootNode *dslNode) error { +func collectPartials(tctx *translationContext, rootNode *dslNode) error { for _, topLevelNode := range rootNode.GetChildren() { if topLevelNode.GetType() == dslshape.NodeTypePartial { err := translatePartial(tctx, topLevelNode) @@ -825,7 +857,7 @@ func collectPartials(tctx translationContext, rootNode *dslNode) error { } // This function modifies the translation context, so we don't need to return anything from it. -func translatePartial(tctx translationContext, partialNode *dslNode) error { +func translatePartial(tctx *translationContext, partialNode *dslNode) error { partialName, err := partialNode.GetString(dslshape.NodePartialPredicateName) if err != nil { return err @@ -863,7 +895,7 @@ func translatePartial(tctx translationContext, partialNode *dslNode) error { // NOTE: we treat partial references in definitions and partials differently because a missing partial // reference in definition compilation is an error state, where a missing partial reference in a // partial definition is an indeterminate state. -func translatePartialReference(tctx translationContext, partialReferenceNode *dslNode, errorOnMissingReference bool) ([]*core.Relation, string, error) { +func translatePartialReference(tctx *translationContext, partialReferenceNode *dslNode, errorOnMissingReference bool) ([]*core.Relation, string, error) { name, err := partialReferenceNode.GetString(dslshape.NodePartialReferencePredicateName) if err != nil { return nil, "", err @@ -879,3 +911,16 @@ func translatePartialReference(tctx translationContext, partialReferenceNode *ds } return relationsAndPermissions, "", nil } + +// Translate use node and add flag to list of enabled flags +func translateUseFlag(tctx *translationContext, useFlagNode *dslNode) error { + flagName, err := useFlagNode.GetString(dslshape.NodeUseFlagPredicateName) + if err != nil { + return err + } + if slices.Contains(tctx.enabledFlags, flagName) { + return useFlagNode.Errorf("found duplicate use flag: %s", flagName) + } + tctx.enabledFlags = append(tctx.enabledFlags, flagName) + return nil +} diff --git a/pkg/composableschemadsl/generator/generator.go b/pkg/composableschemadsl/generator/generator.go index 36f48475f0..48212642f5 100644 --- a/pkg/composableschemadsl/generator/generator.go +++ b/pkg/composableschemadsl/generator/generator.go @@ -11,6 +11,7 @@ import ( "github.com/authzed/spicedb/pkg/caveats" caveattypes "github.com/authzed/spicedb/pkg/caveats/types" "github.com/authzed/spicedb/pkg/composableschemadsl/compiler" + "github.com/authzed/spicedb/pkg/genutil/mapz" "github.com/authzed/spicedb/pkg/graph" "github.com/authzed/spicedb/pkg/namespace" core "github.com/authzed/spicedb/pkg/proto/core/v1" @@ -26,6 +27,8 @@ const MaxSingleLineCommentLength = 70 // 80 - the comment parts and some padding // GenerateSchema generates a DSL view of the given schema. func GenerateSchema(definitions []compiler.SchemaDefinition) (string, bool, error) { generated := make([]string, 0, len(definitions)) + flags := mapz.NewSet[string]() + result := true for _, definition := range definitions { switch def := definition.(type) { @@ -39,19 +42,29 @@ func GenerateSchema(definitions []compiler.SchemaDefinition) (string, bool, erro generated = append(generated, generatedCaveat) case *core.NamespaceDefinition: - generatedSchema, ok, err := GenerateSource(def) + generatedSchema, defFlags, ok, err := generateDefinitionSource(def) if err != nil { return "", false, err } result = result && ok generated = append(generated, generatedSchema) + flags.Extend(defFlags) default: return "", false, spiceerrors.MustBugf("unknown type of definition %T in GenerateSchema", def) } } + if !flags.IsEmpty() { + flagsSlice := flags.AsSlice() + sort.Strings(flagsSlice) + + for _, flag := range flagsSlice { + generated = append([]string{"use " + flag}, generated...) + } + } + return strings.Join(generated, "\n\n"), result, nil } @@ -74,19 +87,25 @@ func GenerateCaveatSource(caveat *core.CaveatDefinition) (string, bool, error) { // GenerateSource generates a DSL view of the given namespace definition. func GenerateSource(namespace *core.NamespaceDefinition) (string, bool, error) { + source, _, ok, err := generateDefinitionSource(namespace) + return source, ok, err +} + +func generateDefinitionSource(namespace *core.NamespaceDefinition) (string, []string, bool, error) { generator := &sourceGenerator{ indentationLevel: 0, hasNewline: true, hasBlankline: true, hasNewScope: true, + flags: mapz.NewSet[string](), } err := generator.emitNamespace(namespace) if err != nil { - return "", false, err + return "", nil, false, err } - return generator.buf.String(), !generator.hasIssue, nil + return generator.buf.String(), generator.flags.AsSlice(), !generator.hasIssue, nil } // GenerateRelationSource generates a DSL view of the given relation definition. @@ -237,9 +256,25 @@ func (sg *sourceGenerator) emitAllowedRelation(allowedRelation *core.AllowedRela if allowedRelation.GetPublicWildcard() != nil { sg.append(":*") } - if allowedRelation.GetRequiredCaveat() != nil { + + hasExpirationTrait := allowedRelation.GetRequiredExpiration() != nil + hasCaveat := allowedRelation.GetRequiredCaveat() != nil + + if hasExpirationTrait || hasCaveat { sg.append(" with ") - sg.append(allowedRelation.RequiredCaveat.CaveatName) + if hasCaveat { + sg.append(allowedRelation.RequiredCaveat.CaveatName) + } + + if hasExpirationTrait { + sg.flags.Add("expiration") + + if hasCaveat { + sg.append(" and ") + } + + sg.append("expiration") + } } } diff --git a/pkg/composableschemadsl/generator/generator_impl.go b/pkg/composableschemadsl/generator/generator_impl.go index 5bc11c0ba5..e8f9db06a5 100644 --- a/pkg/composableschemadsl/generator/generator_impl.go +++ b/pkg/composableschemadsl/generator/generator_impl.go @@ -2,16 +2,19 @@ package generator import ( "strings" + + "github.com/authzed/spicedb/pkg/genutil/mapz" ) type sourceGenerator struct { - buf strings.Builder // The buffer for the new source code. - indentationLevel int // The current indentation level. - hasNewline bool // Whether there is a newline at the end of the buffer. - hasBlankline bool // Whether there is a blank line at the end of the buffer. - hasIssue bool // Whether there is a translation issue. - hasNewScope bool // Whether there is a new scope at the end of the buffer. - existingLineLength int // Length of the existing line. + buf strings.Builder // The buffer for the new source code. + indentationLevel int // The current indentation level. + hasNewline bool // Whether there is a newline at the end of the buffer. + hasBlankline bool // Whether there is a blank line at the end of the buffer. + hasIssue bool // Whether there is a translation issue. + hasNewScope bool // Whether there is a new scope at the end of the buffer. + existingLineLength int // Length of the existing line. + flags *mapz.Set[string] // The flags added while generating. } // ensureBlankLineOrNewScope ensures that there is a blank line or new scope at the tail of the buffer. If not, diff --git a/pkg/composableschemadsl/generator/generator_test.go b/pkg/composableschemadsl/generator/generator_test.go index 2519f2d9b4..b10596565e 100644 --- a/pkg/composableschemadsl/generator/generator_test.go +++ b/pkg/composableschemadsl/generator/generator_test.go @@ -304,7 +304,6 @@ definition foos/test { relation somerel: foos/bars }`, }, - { "full example", ` @@ -360,6 +359,32 @@ definition foos/document { }`, `definition document { permission first = rela->relb + relc.any(reld) + rele.all(relf) +}`, + }, + { + "expiration trait", + `use expiration + + definition document{ + relation viewer: user with expiration + relation editor: user with somecaveat and expiration + }`, + `use expiration + +definition document { + relation viewer: user with expiration + relation editor: user with somecaveat and expiration +}`, + }, + { + "unused expiration flag", + `use expiration + + definition document{ + relation viewer: user + }`, + `definition document { + relation viewer: user }`, }, } diff --git a/pkg/composableschemadsl/lexer/flaggablelexer.go b/pkg/composableschemadsl/lexer/flaggablelexer.go index d0700337bf..1ae99af777 100644 --- a/pkg/composableschemadsl/lexer/flaggablelexer.go +++ b/pkg/composableschemadsl/lexer/flaggablelexer.go @@ -1,28 +1,28 @@ package lexer -// FlaggableLexler wraps a lexer, automatically translating tokens based on flags, if any. -type FlaggableLexler struct { +// FlaggableLexer wraps a lexer, automatically translating tokens based on flags, if any. +type FlaggableLexer struct { lex *Lexer // a reference to the lexer used for tokenization enabledFlags map[string]transformer // flags that are enabled seenDefinition bool afterUseIdentifier bool } -// NewFlaggableLexler returns a new FlaggableLexler for the given lexer. -func NewFlaggableLexler(lex *Lexer) *FlaggableLexler { - return &FlaggableLexler{ +// NewFlaggableLexer returns a new FlaggableLexer for the given lexer. +func NewFlaggableLexer(lex *Lexer) *FlaggableLexer { + return &FlaggableLexer{ lex: lex, enabledFlags: map[string]transformer{}, } } // Close stops the lexer from running. -func (l *FlaggableLexler) Close() { +func (l *FlaggableLexer) Close() { l.lex.Close() } // NextToken returns the next token found in the lexer. -func (l *FlaggableLexler) NextToken() Lexeme { +func (l *FlaggableLexer) NextToken() Lexeme { nextToken := l.lex.nextToken() // Look for `use somefeature` diff --git a/pkg/composableschemadsl/lexer/flaggablelexer_test.go b/pkg/composableschemadsl/lexer/flaggablelexer_test.go index 38ed27fd0b..15a56e8466 100644 --- a/pkg/composableschemadsl/lexer/flaggablelexer_test.go +++ b/pkg/composableschemadsl/lexer/flaggablelexer_test.go @@ -62,7 +62,7 @@ func TestFlaggableLexer(t *testing.T) { } func performFlaggedLex(t *lexerTest) (tokens []Lexeme) { - lexer := NewFlaggableLexler(Lex(input.Source(t.name), t.input)) + lexer := NewFlaggableLexer(Lex(input.Source(t.name), t.input)) for { token := lexer.NextToken() tokens = append(tokens, token) diff --git a/pkg/composableschemadsl/parser/parser.go b/pkg/composableschemadsl/parser/parser.go index f97ec22161..08eedc860b 100644 --- a/pkg/composableschemadsl/parser/parser.go +++ b/pkg/composableschemadsl/parser/parser.go @@ -48,12 +48,6 @@ Loop: break Loop } - if !hasSeenDefinition { - if p.isKeyword("use") { - rootNode.Connect(dslshape.NodePredicateChild, p.consumeUseFlag()) - } - } - // Consume a statement terminator if one was found. p.tryConsumeStatementTerminator() @@ -66,6 +60,9 @@ Loop: // caveat somecaveat (...) { ... } switch { + case p.isKeyword("use"): + rootNode.Connect(dslshape.NodePredicateChild, p.consumeUseFlag(hasSeenDefinition)) + case p.isKeyword("definition"): hasSeenDefinition = true rootNode.Connect(dslshape.NodePredicateChild, p.consumeDefinition()) @@ -251,7 +248,7 @@ func (p *sourceParser) consumeCaveatTypeReference() AstNode { // consumeUseFlag attempts to consume a use flag. // ``` use flagname ``` -func (p *sourceParser) consumeUseFlag() AstNode { +func (p *sourceParser) consumeUseFlag(afterDefinition bool) AstNode { useNode := p.startNode(dslshape.NodeTypeUseFlag) defer p.mustFinishNode() @@ -275,6 +272,16 @@ func (p *sourceParser) consumeUseFlag() AstNode { } useNode.MustDecorate(dslshape.NodeUseFlagPredicateName, useFlag) + + // NOTE: we conduct this check in `consumeFlag` rather than at + // the callsite to keep the callsite clean. + // We also do the check after consumption to ensure that the parser continues + // moving past the use expression. + if afterDefinition { + p.emitErrorf("`use` expressions must be declared before any definition") + return useNode + } + return useNode } diff --git a/pkg/composableschemadsl/parser/parser_impl.go b/pkg/composableschemadsl/parser/parser_impl.go index 076e98fb2d..b5f13aa5b9 100644 --- a/pkg/composableschemadsl/parser/parser_impl.go +++ b/pkg/composableschemadsl/parser/parser_impl.go @@ -45,18 +45,18 @@ type commentedLexeme struct { // sourceParser holds the state of the parser. type sourceParser struct { - source input.Source // the name of the input; used only for error reports - input string // the input string itself - lex *lexer.FlaggableLexler // a reference to the lexer used for tokenization - builder NodeBuilder // the builder function for creating AstNode instances - nodes *nodeStack // the stack of the current nodes - currentToken commentedLexeme // the current token - previousToken commentedLexeme // the previous token + source input.Source // the name of the input; used only for error reports + input string // the input string itself + lex *lexer.FlaggableLexer // a reference to the lexer used for tokenization + builder NodeBuilder // the builder function for creating AstNode instances + nodes *nodeStack // the stack of the current nodes + currentToken commentedLexeme // the current token + previousToken commentedLexeme // the previous token } // buildParser returns a new sourceParser instance. func buildParser(lx *lexer.Lexer, builder NodeBuilder, source input.Source, input string) *sourceParser { - l := lexer.NewFlaggableLexler(lx) + l := lexer.NewFlaggableLexer(lx) return &sourceParser{ source: source, input: input, diff --git a/pkg/composableschemadsl/parser/parser_test.go b/pkg/composableschemadsl/parser/parser_test.go index b05762d5e1..79695919ed 100644 --- a/pkg/composableschemadsl/parser/parser_test.go +++ b/pkg/composableschemadsl/parser/parser_test.go @@ -123,6 +123,10 @@ func TestParser(t *testing.T) { {"arrow illegal function test", "arrowillegalfunc"}, {"caveat with keyword parameter test", "caveatwithkeywordparam"}, {"use expiration test", "useexpiration"}, + // NOTE: a duplicate use statement should be an error at the level + // of the compiler, but should not be an error at the level + // of the parser. + {"duplicate use statement test", "duplicate_use_statement"}, {"use expiration keyword test", "useexpirationkeyword"}, {"expiration non-keyword test", "expirationnonkeyword"}, {"invalid use", "invaliduse"}, diff --git a/pkg/composableschemadsl/parser/tests/duplicate_use_statement.zed b/pkg/composableschemadsl/parser/tests/duplicate_use_statement.zed new file mode 100644 index 0000000000..a1e17f5730 --- /dev/null +++ b/pkg/composableschemadsl/parser/tests/duplicate_use_statement.zed @@ -0,0 +1,7 @@ +use expiration +use expiration + +definition resource { + relation viewer: user with expiration + relation editor: user with somecaveat and expiration +} diff --git a/pkg/composableschemadsl/parser/tests/duplicate_use_statement.zed.expected b/pkg/composableschemadsl/parser/tests/duplicate_use_statement.zed.expected new file mode 100644 index 0000000000..4f616df711 --- /dev/null +++ b/pkg/composableschemadsl/parser/tests/duplicate_use_statement.zed.expected @@ -0,0 +1,71 @@ +NodeTypeFile + end-rune = 153 + input-source = duplicate use statement test + start-rune = 0 + child-node => + NodeTypeUseFlag + end-rune = 13 + input-source = duplicate use statement test + start-rune = 0 + use-flag-name = expiration + NodeTypeUseFlag + end-rune = 28 + input-source = duplicate use statement test + start-rune = 15 + use-flag-name = expiration + NodeTypeDefinition + definition-name = resource + end-rune = 152 + input-source = duplicate use statement test + start-rune = 31 + child-node => + NodeTypeRelation + end-rune = 93 + input-source = duplicate use statement test + relation-name = viewer + start-rune = 57 + allowed-types => + NodeTypeTypeReference + end-rune = 93 + input-source = duplicate use statement test + start-rune = 74 + type-ref-type => + NodeTypeSpecificTypeReference + end-rune = 93 + input-source = duplicate use statement test + start-rune = 74 + type-name = user + trait => + NodeTypeTraitReference + end-rune = 93 + input-source = duplicate use statement test + start-rune = 84 + trait-name = expiration + NodeTypeRelation + end-rune = 150 + input-source = duplicate use statement test + relation-name = editor + start-rune = 99 + allowed-types => + NodeTypeTypeReference + end-rune = 150 + input-source = duplicate use statement test + start-rune = 116 + type-ref-type => + NodeTypeSpecificTypeReference + end-rune = 150 + input-source = duplicate use statement test + start-rune = 116 + type-name = user + caveat => + NodeTypeCaveatReference + caveat-name = somecaveat + end-rune = 135 + input-source = duplicate use statement test + start-rune = 126 + trait => + NodeTypeTraitReference + end-rune = 150 + input-source = duplicate use statement test + start-rune = 141 + trait-name = expiration \ No newline at end of file diff --git a/pkg/composableschemadsl/parser/tests/useafterdef.zed.expected b/pkg/composableschemadsl/parser/tests/useafterdef.zed.expected index 29a46c7d2a..1339b5a729 100644 --- a/pkg/composableschemadsl/parser/tests/useafterdef.zed.expected +++ b/pkg/composableschemadsl/parser/tests/useafterdef.zed.expected @@ -1,5 +1,5 @@ NodeTypeFile - end-rune = 22 + end-rune = 37 input-source = use after definition start-rune = 0 child-node => @@ -8,9 +8,14 @@ NodeTypeFile end-rune = 21 input-source = use after definition start-rune = 0 - NodeTypeError - end-rune = 22 - error-message = Unexpected token at root level: TokenTypeKeyword - error-source = use + NodeTypeUseFlag + end-rune = 37 input-source = use after definition - start-rune = 24 \ No newline at end of file + start-rune = 24 + use-flag-name = expiration + child-node => + NodeTypeError + end-rune = 37 + error-message = `use` expressions must be declared before any definition + input-source = use after definition + start-rune = 38 \ No newline at end of file diff --git a/pkg/schemadsl/compiler/compiler.go b/pkg/schemadsl/compiler/compiler.go index 618f7deabb..31fdc0650a 100644 --- a/pkg/schemadsl/compiler/compiler.go +++ b/pkg/schemadsl/compiler/compiler.go @@ -107,7 +107,7 @@ func Compile(schema InputSchema, prefix ObjectPrefixOption, opts ...Option) (*Co return nil, err } - compiled, err := translate(translationContext{ + compiled, err := translate(&translationContext{ objectTypePrefix: cfg.objectTypePrefix, mapper: mapper, schemaString: schema.SchemaString, diff --git a/pkg/schemadsl/compiler/compiler_test.go b/pkg/schemadsl/compiler/compiler_test.go index 114b8ffef4..b37014c9f8 100644 --- a/pkg/schemadsl/compiler/compiler_test.go +++ b/pkg/schemadsl/compiler/compiler_test.go @@ -1006,6 +1006,19 @@ func TestCompile(t *testing.T) { ), }, }, + { + "duplicate use pragmas", + withTenantPrefix, + ` + use expiration + use expiration + + definition simple { + relation viewer: user with expiration + }`, + `found duplicate use flag`, + []SchemaDefinition{}, + }, { "relation with expiration trait and caveat", withTenantPrefix, diff --git a/pkg/schemadsl/compiler/translator.go b/pkg/schemadsl/compiler/translator.go index 3ad5186b62..ebdfc402d9 100644 --- a/pkg/schemadsl/compiler/translator.go +++ b/pkg/schemadsl/compiler/translator.go @@ -24,9 +24,10 @@ type translationContext struct { schemaString string skipValidate bool allowedFlags []string + enabledFlags []string } -func (tctx translationContext) prefixedPath(definitionName string) (string, error) { +func (tctx *translationContext) prefixedPath(definitionName string) (string, error) { var prefix, name string if err := stringz.SplitInto(definitionName, "/", &prefix, &name); err != nil { if tctx.objectTypePrefix == nil { @@ -45,7 +46,7 @@ func (tctx translationContext) prefixedPath(definitionName string) (string, erro const Ellipsis = "..." -func translate(tctx translationContext, root *dslNode) (*CompiledSchema, error) { +func translate(tctx *translationContext, root *dslNode) (*CompiledSchema, error) { orderedDefinitions := make([]SchemaDefinition, 0, len(root.GetChildren())) var objectDefinitions []*core.NamespaceDefinition var caveatDefinitions []*core.CaveatDefinition @@ -57,7 +58,10 @@ func translate(tctx translationContext, root *dslNode) (*CompiledSchema, error) switch definitionNode.GetType() { case dslshape.NodeTypeUseFlag: - // Skip the flags. + err := translateUseFlag(tctx, definitionNode) + if err != nil { + return nil, err + } continue case dslshape.NodeTypeCaveatDefinition: @@ -95,7 +99,7 @@ func translate(tctx translationContext, root *dslNode) (*CompiledSchema, error) }, nil } -func translateCaveatDefinition(tctx translationContext, defNode *dslNode) (*core.CaveatDefinition, error) { +func translateCaveatDefinition(tctx *translationContext, defNode *dslNode) (*core.CaveatDefinition, error) { definitionName, err := defNode.GetString(dslshape.NodeCaveatDefinitionPredicateName) if err != nil { return nil, defNode.WithSourceErrorf(definitionName, "invalid definition name: %w", err) @@ -177,7 +181,7 @@ func translateCaveatDefinition(tctx translationContext, defNode *dslNode) (*core return def, nil } -func translateCaveatTypeReference(tctx translationContext, typeRefNode *dslNode) (*caveattypes.VariableType, error) { +func translateCaveatTypeReference(tctx *translationContext, typeRefNode *dslNode) (*caveattypes.VariableType, error) { typeName, err := typeRefNode.GetString(dslshape.NodeCaveatTypeReferencePredicateType) if err != nil { return nil, typeRefNode.WithSourceErrorf(typeName, "invalid type name: %w", err) @@ -201,7 +205,7 @@ func translateCaveatTypeReference(tctx translationContext, typeRefNode *dslNode) return constructedType, nil } -func translateObjectDefinition(tctx translationContext, defNode *dslNode) (*core.NamespaceDefinition, error) { +func translateObjectDefinition(tctx *translationContext, defNode *dslNode) (*core.NamespaceDefinition, error) { definitionName, err := defNode.GetString(dslshape.NodeDefinitionPredicateName) if err != nil { return nil, defNode.WithSourceErrorf(definitionName, "invalid definition name: %w", err) @@ -300,7 +304,7 @@ func normalizeComment(value string) string { return strings.Join(lines, "\n") } -func translateRelationOrPermission(tctx translationContext, relOrPermNode *dslNode) (*core.Relation, error) { +func translateRelationOrPermission(tctx *translationContext, relOrPermNode *dslNode) (*core.Relation, error) { switch relOrPermNode.GetType() { case dslshape.NodeTypeRelation: rel, err := translateRelation(tctx, relOrPermNode) @@ -325,7 +329,7 @@ func translateRelationOrPermission(tctx translationContext, relOrPermNode *dslNo } } -func translateRelation(tctx translationContext, relationNode *dslNode) (*core.Relation, error) { +func translateRelation(tctx *translationContext, relationNode *dslNode) (*core.Relation, error) { relationName, err := relationNode.GetString(dslshape.NodePredicateName) if err != nil { return nil, relationNode.Errorf("invalid relation name: %w", err) @@ -355,7 +359,7 @@ func translateRelation(tctx translationContext, relationNode *dslNode) (*core.Re return relation, nil } -func translatePermission(tctx translationContext, permissionNode *dslNode) (*core.Relation, error) { +func translatePermission(tctx *translationContext, permissionNode *dslNode) (*core.Relation, error) { permissionName, err := permissionNode.GetString(dslshape.NodePredicateName) if err != nil { return nil, permissionNode.Errorf("invalid permission name: %w", err) @@ -385,7 +389,7 @@ func translatePermission(tctx translationContext, permissionNode *dslNode) (*cor return permission, nil } -func translateBinary(tctx translationContext, expressionNode *dslNode) (*core.SetOperation_Child, *core.SetOperation_Child, error) { +func translateBinary(tctx *translationContext, expressionNode *dslNode) (*core.SetOperation_Child, *core.SetOperation_Child, error) { leftChild, err := expressionNode.Lookup(dslshape.NodeExpressionPredicateLeftExpr) if err != nil { return nil, nil, err @@ -409,7 +413,7 @@ func translateBinary(tctx translationContext, expressionNode *dslNode) (*core.Se return leftOperation, rightOperation, nil } -func translateExpression(tctx translationContext, expressionNode *dslNode) (*core.UsersetRewrite, error) { +func translateExpression(tctx *translationContext, expressionNode *dslNode) (*core.UsersetRewrite, error) { translated, err := translateExpressionDirect(tctx, expressionNode) if err != nil { return translated, err @@ -437,7 +441,7 @@ func collapseOps(op *core.SetOperation_Child, handler func(rewrite *core.Userset return collapsed } -func translateExpressionDirect(tctx translationContext, expressionNode *dslNode) (*core.UsersetRewrite, error) { +func translateExpressionDirect(tctx *translationContext, expressionNode *dslNode) (*core.UsersetRewrite, error) { // For union and intersection, we collapse a tree of binary operations into a flat list containing child // operations of the *same* type. translate := func( @@ -483,7 +487,7 @@ func translateExpressionDirect(tctx translationContext, expressionNode *dslNode) } } -func translateExpressionOperation(tctx translationContext, expressionOpNode *dslNode) (*core.SetOperation_Child, error) { +func translateExpressionOperation(tctx *translationContext, expressionOpNode *dslNode) (*core.SetOperation_Child, error) { translated, err := translateExpressionOperationDirect(tctx, expressionOpNode) if err != nil { return translated, err @@ -493,7 +497,7 @@ func translateExpressionOperation(tctx translationContext, expressionOpNode *dsl return translated, nil } -func translateExpressionOperationDirect(tctx translationContext, expressionOpNode *dslNode) (*core.SetOperation_Child, error) { +func translateExpressionOperationDirect(tctx *translationContext, expressionOpNode *dslNode) (*core.SetOperation_Child, error) { switch expressionOpNode.GetType() { case dslshape.NodeTypeIdentifier: referencedRelationName, err := expressionOpNode.GetString(dslshape.NodeIdentiferPredicateValue) @@ -560,7 +564,7 @@ func translateExpressionOperationDirect(tctx translationContext, expressionOpNod } } -func translateAllowedRelations(tctx translationContext, typeRefNode *dslNode) ([]*core.AllowedRelation, error) { +func translateAllowedRelations(tctx *translationContext, typeRefNode *dslNode) ([]*core.AllowedRelation, error) { switch typeRefNode.GetType() { case dslshape.NodeTypeTypeReference: references := []*core.AllowedRelation{} @@ -586,7 +590,7 @@ func translateAllowedRelations(tctx translationContext, typeRefNode *dslNode) ([ } } -func translateSpecificTypeReference(tctx translationContext, typeRefNode *dslNode) (*core.AllowedRelation, error) { +func translateSpecificTypeReference(tctx *translationContext, typeRefNode *dslNode) (*core.AllowedRelation, error) { typePath, err := typeRefNode.GetString(dslshape.NodeSpecificReferencePredicateType) if err != nil { return nil, typeRefNode.Errorf("invalid type name: %w", err) @@ -669,7 +673,7 @@ func translateSpecificTypeReference(tctx translationContext, typeRefNode *dslNod return ref, nil } -func addWithCaveats(tctx translationContext, typeRefNode *dslNode, ref *core.AllowedRelation) error { +func addWithCaveats(tctx *translationContext, typeRefNode *dslNode, ref *core.AllowedRelation) error { caveats := typeRefNode.List(dslshape.NodeSpecificReferencePredicateCaveat) if len(caveats) == 0 { return nil @@ -694,3 +698,16 @@ func addWithCaveats(tctx translationContext, typeRefNode *dslNode, ref *core.All } return nil } + +// Translate use node and add flag to list of enabled flags +func translateUseFlag(tctx *translationContext, useFlagNode *dslNode) error { + flagName, err := useFlagNode.GetString(dslshape.NodeUseFlagPredicateName) + if err != nil { + return err + } + if slices.Contains(tctx.enabledFlags, flagName) { + return useFlagNode.Errorf("found duplicate use flag: %s", flagName) + } + tctx.enabledFlags = append(tctx.enabledFlags, flagName) + return nil +} diff --git a/pkg/schemadsl/lexer/flaggablelexer.go b/pkg/schemadsl/lexer/flaggablelexer.go index d0700337bf..1ae99af777 100644 --- a/pkg/schemadsl/lexer/flaggablelexer.go +++ b/pkg/schemadsl/lexer/flaggablelexer.go @@ -1,28 +1,28 @@ package lexer -// FlaggableLexler wraps a lexer, automatically translating tokens based on flags, if any. -type FlaggableLexler struct { +// FlaggableLexer wraps a lexer, automatically translating tokens based on flags, if any. +type FlaggableLexer struct { lex *Lexer // a reference to the lexer used for tokenization enabledFlags map[string]transformer // flags that are enabled seenDefinition bool afterUseIdentifier bool } -// NewFlaggableLexler returns a new FlaggableLexler for the given lexer. -func NewFlaggableLexler(lex *Lexer) *FlaggableLexler { - return &FlaggableLexler{ +// NewFlaggableLexer returns a new FlaggableLexer for the given lexer. +func NewFlaggableLexer(lex *Lexer) *FlaggableLexer { + return &FlaggableLexer{ lex: lex, enabledFlags: map[string]transformer{}, } } // Close stops the lexer from running. -func (l *FlaggableLexler) Close() { +func (l *FlaggableLexer) Close() { l.lex.Close() } // NextToken returns the next token found in the lexer. -func (l *FlaggableLexler) NextToken() Lexeme { +func (l *FlaggableLexer) NextToken() Lexeme { nextToken := l.lex.nextToken() // Look for `use somefeature` diff --git a/pkg/schemadsl/lexer/flaggablelexer_test.go b/pkg/schemadsl/lexer/flaggablelexer_test.go index b2da4c4a8e..3ecc89bf11 100644 --- a/pkg/schemadsl/lexer/flaggablelexer_test.go +++ b/pkg/schemadsl/lexer/flaggablelexer_test.go @@ -62,7 +62,7 @@ func TestFlaggableLexer(t *testing.T) { } func performFlaggedLex(t *lexerTest) (tokens []Lexeme) { - lexer := NewFlaggableLexler(Lex(input.Source(t.name), t.input)) + lexer := NewFlaggableLexer(Lex(input.Source(t.name), t.input)) for { token := lexer.NextToken() tokens = append(tokens, token) diff --git a/pkg/schemadsl/parser/parser.go b/pkg/schemadsl/parser/parser.go index 0b1ad2b69c..100ad9bf4f 100644 --- a/pkg/schemadsl/parser/parser.go +++ b/pkg/schemadsl/parser/parser.go @@ -48,12 +48,6 @@ Loop: break Loop } - if !hasSeenDefinition { - if p.isIdentifier("use") { - rootNode.Connect(dslshape.NodePredicateChild, p.consumeUseFlag()) - } - } - // Consume a statement terminator if one was found. p.tryConsumeStatementTerminator() @@ -66,6 +60,9 @@ Loop: // caveat somecaveat (...) { ... } switch { + case p.isIdentifier("use"): + rootNode.Connect(dslshape.NodePredicateChild, p.consumeUseFlag(hasSeenDefinition)) + case p.isKeyword("definition"): hasSeenDefinition = true rootNode.Connect(dslshape.NodePredicateChild, p.consumeDefinition()) @@ -245,7 +242,7 @@ func (p *sourceParser) consumeCaveatTypeReference() AstNode { // consumeUseFlag attempts to consume a use flag. // ``` use flagname ``` -func (p *sourceParser) consumeUseFlag() AstNode { +func (p *sourceParser) consumeUseFlag(afterDefinition bool) AstNode { useNode := p.startNode(dslshape.NodeTypeUseFlag) defer p.mustFinishNode() @@ -269,6 +266,16 @@ func (p *sourceParser) consumeUseFlag() AstNode { } useNode.MustDecorate(dslshape.NodeUseFlagPredicateName, useFlag) + + // NOTE: we conduct this check in `consumeFlag` rather than at + // the callsite to keep the callsite clean. + // We also do the check after consumption to ensure that the parser continues + // moving past the use expression. + if afterDefinition { + p.emitErrorf("`use` expressions must be declared before any definition") + return useNode + } + return useNode } diff --git a/pkg/schemadsl/parser/parser_impl.go b/pkg/schemadsl/parser/parser_impl.go index c63aec34ed..1009b6b557 100644 --- a/pkg/schemadsl/parser/parser_impl.go +++ b/pkg/schemadsl/parser/parser_impl.go @@ -44,18 +44,18 @@ type commentedLexeme struct { // sourceParser holds the state of the parser. type sourceParser struct { - source input.Source // the name of the input; used only for error reports - input string // the input string itself - lex *lexer.FlaggableLexler // a reference to the lexer used for tokenization - builder NodeBuilder // the builder function for creating AstNode instances - nodes *nodeStack // the stack of the current nodes - currentToken commentedLexeme // the current token - previousToken commentedLexeme // the previous token + source input.Source // the name of the input; used only for error reports + input string // the input string itself + lex *lexer.FlaggableLexer // a reference to the lexer used for tokenization + builder NodeBuilder // the builder function for creating AstNode instances + nodes *nodeStack // the stack of the current nodes + currentToken commentedLexeme // the current token + previousToken commentedLexeme // the previous token } // buildParser returns a new sourceParser instance. func buildParser(lx *lexer.Lexer, builder NodeBuilder, source input.Source, input string) *sourceParser { - l := lexer.NewFlaggableLexler(lx) + l := lexer.NewFlaggableLexer(lx) return &sourceParser{ source: source, input: input, diff --git a/pkg/schemadsl/parser/tests/duplicate_use_statement.zed b/pkg/schemadsl/parser/tests/duplicate_use_statement.zed new file mode 100644 index 0000000000..a1e17f5730 --- /dev/null +++ b/pkg/schemadsl/parser/tests/duplicate_use_statement.zed @@ -0,0 +1,7 @@ +use expiration +use expiration + +definition resource { + relation viewer: user with expiration + relation editor: user with somecaveat and expiration +} diff --git a/pkg/schemadsl/parser/tests/duplicate_use_statement.zed.expected b/pkg/schemadsl/parser/tests/duplicate_use_statement.zed.expected new file mode 100644 index 0000000000..4f616df711 --- /dev/null +++ b/pkg/schemadsl/parser/tests/duplicate_use_statement.zed.expected @@ -0,0 +1,71 @@ +NodeTypeFile + end-rune = 153 + input-source = duplicate use statement test + start-rune = 0 + child-node => + NodeTypeUseFlag + end-rune = 13 + input-source = duplicate use statement test + start-rune = 0 + use-flag-name = expiration + NodeTypeUseFlag + end-rune = 28 + input-source = duplicate use statement test + start-rune = 15 + use-flag-name = expiration + NodeTypeDefinition + definition-name = resource + end-rune = 152 + input-source = duplicate use statement test + start-rune = 31 + child-node => + NodeTypeRelation + end-rune = 93 + input-source = duplicate use statement test + relation-name = viewer + start-rune = 57 + allowed-types => + NodeTypeTypeReference + end-rune = 93 + input-source = duplicate use statement test + start-rune = 74 + type-ref-type => + NodeTypeSpecificTypeReference + end-rune = 93 + input-source = duplicate use statement test + start-rune = 74 + type-name = user + trait => + NodeTypeTraitReference + end-rune = 93 + input-source = duplicate use statement test + start-rune = 84 + trait-name = expiration + NodeTypeRelation + end-rune = 150 + input-source = duplicate use statement test + relation-name = editor + start-rune = 99 + allowed-types => + NodeTypeTypeReference + end-rune = 150 + input-source = duplicate use statement test + start-rune = 116 + type-ref-type => + NodeTypeSpecificTypeReference + end-rune = 150 + input-source = duplicate use statement test + start-rune = 116 + type-name = user + caveat => + NodeTypeCaveatReference + caveat-name = somecaveat + end-rune = 135 + input-source = duplicate use statement test + start-rune = 126 + trait => + NodeTypeTraitReference + end-rune = 150 + input-source = duplicate use statement test + start-rune = 141 + trait-name = expiration \ No newline at end of file diff --git a/pkg/schemadsl/parser/tests/useafterdef.zed.expected b/pkg/schemadsl/parser/tests/useafterdef.zed.expected index 7154b75c12..1339b5a729 100644 --- a/pkg/schemadsl/parser/tests/useafterdef.zed.expected +++ b/pkg/schemadsl/parser/tests/useafterdef.zed.expected @@ -1,5 +1,5 @@ NodeTypeFile - end-rune = 22 + end-rune = 37 input-source = use after definition start-rune = 0 child-node => @@ -8,9 +8,14 @@ NodeTypeFile end-rune = 21 input-source = use after definition start-rune = 0 - NodeTypeError - end-rune = 22 - error-message = Unexpected token at root level: TokenTypeIdentifier - error-source = use + NodeTypeUseFlag + end-rune = 37 input-source = use after definition - start-rune = 24 \ No newline at end of file + start-rune = 24 + use-flag-name = expiration + child-node => + NodeTypeError + end-rune = 37 + error-message = `use` expressions must be declared before any definition + input-source = use after definition + start-rune = 38 \ No newline at end of file