Skip to content

Commit

Permalink
Port expiration compiler changes into composableschemadsl
Browse files Browse the repository at this point in the history
  • Loading branch information
tstirrat15 committed Feb 10, 2025
1 parent 5e7c223 commit 96f054c
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 14 deletions.
21 changes: 20 additions & 1 deletion pkg/composableschemadsl/compiler/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"

"google.golang.org/protobuf/proto"
"k8s.io/utils/strings/slices"

"github.com/authzed/spicedb/pkg/composableschemadsl/dslshape"
"github.com/authzed/spicedb/pkg/composableschemadsl/input"
Expand Down Expand Up @@ -53,6 +54,7 @@ func (cs CompiledSchema) SourcePositionToRunePosition(source input.Source, posit
type config struct {
skipValidation bool
objectTypePrefix *string
allowedFlags []string
// In an import context, this is the folder containing
// the importing schema (as opposed to imported schemas)
sourceFolder string
Expand All @@ -76,6 +78,16 @@ func AllowUnprefixedObjectType() ObjectPrefixOption {
return func(cfg *config) { cfg.objectTypePrefix = new(string) }
}

const expirationFlag = "expiration"

func DisallowExpirationFlag() Option {
return func(cfg *config) {
cfg.allowedFlags = slices.Filter([]string{}, cfg.allowedFlags, func(s string) bool {
return s != expirationFlag
})
}
}

// Config that supplies the root source folder for compilation. Required
// for relative import syntax to work properly.
func SourceFolder(sourceFolder string) Option {
Expand All @@ -88,7 +100,13 @@ type ObjectPrefixOption func(*config)

// Compile compilers the input schema into a set of namespace definition protos.
func Compile(schema InputSchema, prefix ObjectPrefixOption, opts ...Option) (*CompiledSchema, error) {
cfg := &config{}
cfg := &config{
allowedFlags: make([]string, 0, 1),
}

// Enable `expiration` flag by default.
cfg.allowedFlags = append(cfg.allowedFlags, expirationFlag)

prefix(cfg) // required option

for _, fn := range opts {
Expand Down Expand Up @@ -116,6 +134,7 @@ func Compile(schema InputSchema, prefix ObjectPrefixOption, opts ...Option) (*Co
mapper: mapper,
schemaString: schema.SchemaString,
skipValidate: cfg.skipValidation,
allowedFlags: cfg.allowedFlags,
existingNames: mapz.NewSet[string](),
compiledPartials: make(map[string][]*core.Relation),
unresolvedPartials: mapz.NewMultiMap[string, *dslNode](),
Expand Down
49 changes: 49 additions & 0 deletions pkg/composableschemadsl/compiler/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1206,6 +1206,55 @@ func TestCompile(t *testing.T) {
),
},
},
{
"relation with expiration caveat",
withTenantPrefix,
`definition simple {
relation viewer: user with expiration
}`,
"",
[]SchemaDefinition{
namespace.Namespace("sometenant/simple",
namespace.MustRelation("viewer", nil,
namespace.AllowedRelationWithCaveat("sometenant/user", "...", namespace.AllowedCaveat("sometenant/expiration")),
),
),
},
},
{
"relation with expiration trait",
withTenantPrefix,
`use expiration
definition simple {
relation viewer: user with expiration
}`,
"",
[]SchemaDefinition{
namespace.Namespace("sometenant/simple",
namespace.MustRelation("viewer", nil,
namespace.AllowedRelationWithExpiration("sometenant/user", "..."),
),
),
},
},
{
"relation with expiration trait and caveat",
withTenantPrefix,
`use expiration
definition simple {
relation viewer: user with somecaveat and expiration
}`,
"",
[]SchemaDefinition{
namespace.Namespace("sometenant/simple",
namespace.MustRelation("viewer", nil,
namespace.AllowedRelationWithCaveatAndExpiration("sometenant/user", "...", namespace.AllowedCaveat("sometenant/somecaveat")),
),
),
},
},
}

for _, test := range tests {
Expand Down
26 changes: 26 additions & 0 deletions pkg/composableschemadsl/compiler/translator.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"container/list"
"fmt"
"path/filepath"
"slices"
"strings"

"github.com/ccoveille/go-safecast"
Expand All @@ -25,6 +26,7 @@ type translationContext struct {
mapper input.PositionMapper
schemaString string
skipValidate bool
allowedFlags []string
existingNames *mapz.Set[string]
// The mapping of partial name -> relations represented by the partial
compiledPartials map[string][]*core.Relation
Expand Down Expand Up @@ -71,6 +73,11 @@ func translate(tctx translationContext, root *dslNode) (*CompiledSchema, error)

for _, topLevelNode := range root.GetChildren() {
switch topLevelNode.GetType() {

case dslshape.NodeTypeUseFlag:
// Skip the flags.
continue

case dslshape.NodeTypeCaveatDefinition:
log.Trace().Msg("adding caveat definition")
// TODO: Maybe refactor these in terms of a generic function?
Expand Down Expand Up @@ -680,11 +687,30 @@ func translateSpecificTypeReference(tctx translationContext, typeRefNode *dslNod
},
}

// Add the caveat(s), if any.
err = addWithCaveats(tctx, typeRefNode, ref)
if err != nil {
return nil, typeRefNode.Errorf("invalid caveat: %w", err)
}

// Add the expiration trait, if any.
if traitNode, err := typeRefNode.Lookup(dslshape.NodeSpecificReferencePredicateTrait); err == nil {
traitName, err := traitNode.GetString(dslshape.NodeTraitPredicateTrait)
if err != nil {
return nil, typeRefNode.Errorf("invalid trait: %w", err)
}

if traitName != "expiration" {
return nil, typeRefNode.Errorf("invalid trait: %s", traitName)
}

if !slices.Contains(tctx.allowedFlags, "expiration") {
return nil, typeRefNode.Errorf("expiration trait is not allowed")
}

ref.RequiredExpiration = &core.ExpirationTrait{}
}

if !tctx.skipValidate {
if err := ref.Validate(); err != nil {
return nil, typeRefNode.Errorf("invalid type relation: %w", err)
Expand Down
45 changes: 40 additions & 5 deletions pkg/composableschemadsl/generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/authzed/spicedb/pkg/caveats"
caveattypes "github.com/authzed/spicedb/pkg/caveats/types"
"github.com/authzed/spicedb/pkg/composableschemadsl/compiler"
"github.com/authzed/spicedb/pkg/genutil/mapz"
"github.com/authzed/spicedb/pkg/graph"
"github.com/authzed/spicedb/pkg/namespace"
core "github.com/authzed/spicedb/pkg/proto/core/v1"
Expand All @@ -26,6 +27,8 @@ const MaxSingleLineCommentLength = 70 // 80 - the comment parts and some padding
// GenerateSchema generates a DSL view of the given schema.
func GenerateSchema(definitions []compiler.SchemaDefinition) (string, bool, error) {
generated := make([]string, 0, len(definitions))
flags := mapz.NewSet[string]()

result := true
for _, definition := range definitions {
switch def := definition.(type) {
Expand All @@ -39,19 +42,29 @@ func GenerateSchema(definitions []compiler.SchemaDefinition) (string, bool, erro
generated = append(generated, generatedCaveat)

case *core.NamespaceDefinition:
generatedSchema, ok, err := GenerateSource(def)
generatedSchema, defFlags, ok, err := generateDefinitionSource(def)
if err != nil {
return "", false, err
}

result = result && ok
generated = append(generated, generatedSchema)
flags.Extend(defFlags)

default:
return "", false, spiceerrors.MustBugf("unknown type of definition %T in GenerateSchema", def)
}
}

if !flags.IsEmpty() {
flagsSlice := flags.AsSlice()
sort.Strings(flagsSlice)

for _, flag := range flagsSlice {
generated = append([]string{"use " + flag}, generated...)
}
}

return strings.Join(generated, "\n\n"), result, nil
}

Expand All @@ -74,19 +87,25 @@ func GenerateCaveatSource(caveat *core.CaveatDefinition) (string, bool, error) {

// GenerateSource generates a DSL view of the given namespace definition.
func GenerateSource(namespace *core.NamespaceDefinition) (string, bool, error) {
source, _, ok, err := generateDefinitionSource(namespace)
return source, ok, err
}

func generateDefinitionSource(namespace *core.NamespaceDefinition) (string, []string, bool, error) {
generator := &sourceGenerator{
indentationLevel: 0,
hasNewline: true,
hasBlankline: true,
hasNewScope: true,
flags: mapz.NewSet[string](),
}

err := generator.emitNamespace(namespace)
if err != nil {
return "", false, err
return "", nil, false, err
}

return generator.buf.String(), !generator.hasIssue, nil
return generator.buf.String(), generator.flags.AsSlice(), !generator.hasIssue, nil
}

// GenerateRelationSource generates a DSL view of the given relation definition.
Expand Down Expand Up @@ -237,9 +256,25 @@ func (sg *sourceGenerator) emitAllowedRelation(allowedRelation *core.AllowedRela
if allowedRelation.GetPublicWildcard() != nil {
sg.append(":*")
}
if allowedRelation.GetRequiredCaveat() != nil {

hasExpirationTrait := allowedRelation.GetRequiredExpiration() != nil
hasCaveat := allowedRelation.GetRequiredCaveat() != nil

if hasExpirationTrait || hasCaveat {
sg.append(" with ")
sg.append(allowedRelation.RequiredCaveat.CaveatName)
if hasCaveat {
sg.append(allowedRelation.RequiredCaveat.CaveatName)
}

if hasExpirationTrait {
sg.flags.Add("expiration")

if hasCaveat {
sg.append(" and ")
}

sg.append("expiration")
}
}
}

Expand Down
17 changes: 10 additions & 7 deletions pkg/composableschemadsl/generator/generator_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@ package generator

import (
"strings"

"github.com/authzed/spicedb/pkg/genutil/mapz"
)

type sourceGenerator struct {
buf strings.Builder // The buffer for the new source code.
indentationLevel int // The current indentation level.
hasNewline bool // Whether there is a newline at the end of the buffer.
hasBlankline bool // Whether there is a blank line at the end of the buffer.
hasIssue bool // Whether there is a translation issue.
hasNewScope bool // Whether there is a new scope at the end of the buffer.
existingLineLength int // Length of the existing line.
buf strings.Builder // The buffer for the new source code.
indentationLevel int // The current indentation level.
hasNewline bool // Whether there is a newline at the end of the buffer.
hasBlankline bool // Whether there is a blank line at the end of the buffer.
hasIssue bool // Whether there is a translation issue.
hasNewScope bool // Whether there is a new scope at the end of the buffer.
existingLineLength int // Length of the existing line.
flags *mapz.Set[string] // The flags added while generating.
}

// ensureBlankLineOrNewScope ensures that there is a blank line or new scope at the tail of the buffer. If not,
Expand Down
36 changes: 35 additions & 1 deletion pkg/composableschemadsl/generator/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,6 @@ definition foos/test {
relation somerel: foos/bars
}`,
},

{
"full example",
`
Expand Down Expand Up @@ -360,6 +359,41 @@ definition foos/document {
}`,
`definition document {
permission first = rela->relb + relc.any(reld) + rele.all(relf)
}`,
},
{
"expiration caveat",
`definition document{
relation viewer: user with expiration
}`,
`definition document {
relation viewer: user with expiration
}`,
},
{
"expiration trait",
`use expiration
definition document{
relation viewer: user with expiration
relation editor: user with somecaveat and expiration
}`,
`use expiration
definition document {
relation viewer: user with expiration
relation editor: user with somecaveat and expiration
}`,
},
{
"unused expiration flag",
`use expiration
definition document{
relation viewer: user
}`,
`definition document {
relation viewer: user
}`,
},
}
Expand Down

0 comments on commit 96f054c

Please sign in to comment.