Skip to content

Commit

Permalink
Port expiration compiler changes into composableschemadsl
Browse files Browse the repository at this point in the history
  • Loading branch information
tstirrat15 committed Feb 10, 2025
1 parent 5e7c223 commit 22e351c
Show file tree
Hide file tree
Showing 24 changed files with 500 additions and 99 deletions.
23 changes: 21 additions & 2 deletions pkg/composableschemadsl/compiler/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"

"google.golang.org/protobuf/proto"
"k8s.io/utils/strings/slices"

"github.com/authzed/spicedb/pkg/composableschemadsl/dslshape"
"github.com/authzed/spicedb/pkg/composableschemadsl/input"
Expand Down Expand Up @@ -53,6 +54,7 @@ func (cs CompiledSchema) SourcePositionToRunePosition(source input.Source, posit
type config struct {
skipValidation bool
objectTypePrefix *string
allowedFlags []string
// In an import context, this is the folder containing
// the importing schema (as opposed to imported schemas)
sourceFolder string
Expand All @@ -76,6 +78,16 @@ func AllowUnprefixedObjectType() ObjectPrefixOption {
return func(cfg *config) { cfg.objectTypePrefix = new(string) }
}

const expirationFlag = "expiration"

func DisallowExpirationFlag() Option {
return func(cfg *config) {
cfg.allowedFlags = slices.Filter([]string{}, cfg.allowedFlags, func(s string) bool {
return s != expirationFlag
})
}
}

// Config that supplies the root source folder for compilation. Required
// for relative import syntax to work properly.
func SourceFolder(sourceFolder string) Option {
Expand All @@ -88,7 +100,13 @@ type ObjectPrefixOption func(*config)

// Compile compilers the input schema into a set of namespace definition protos.
func Compile(schema InputSchema, prefix ObjectPrefixOption, opts ...Option) (*CompiledSchema, error) {
cfg := &config{}
cfg := &config{
allowedFlags: make([]string, 0, 1),
}

// Enable `expiration` flag by default.
cfg.allowedFlags = append(cfg.allowedFlags, expirationFlag)

prefix(cfg) // required option

for _, fn := range opts {
Expand All @@ -111,11 +129,12 @@ func Compile(schema InputSchema, prefix ObjectPrefixOption, opts ...Option) (*Co
return nil, err
}

compiled, err := translate(translationContext{
compiled, err := translate(&translationContext{
objectTypePrefix: cfg.objectTypePrefix,
mapper: mapper,
schemaString: schema.SchemaString,
skipValidate: cfg.skipValidation,
allowedFlags: cfg.allowedFlags,
existingNames: mapz.NewSet[string](),
compiledPartials: make(map[string][]*core.Relation),
unresolvedPartials: mapz.NewMultiMap[string, *dslNode](),
Expand Down
57 changes: 57 additions & 0 deletions pkg/composableschemadsl/compiler/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1206,6 +1206,63 @@ func TestCompile(t *testing.T) {
),
},
},
{
"relation with expiration trait",
withTenantPrefix,
`use expiration
definition simple {
relation viewer: user with expiration
}`,
"",
[]SchemaDefinition{
namespace.Namespace("sometenant/simple",
namespace.MustRelation("viewer", nil,
namespace.AllowedRelationWithExpiration("sometenant/user", "..."),
),
),
},
},
{
"duplicate use pragmas",
withTenantPrefix,
`
use expiration
use expiration
definition simple {
relation viewer: user with expiration
}`,
`found duplicate use flag`,
[]SchemaDefinition{},
},
{
"expiration use without use expiration",
withTenantPrefix,
`
definition simple {
relation viewer: user with expiration
}`,
`expiration flag is not enabled`,
[]SchemaDefinition{},
},
{
"relation with expiration trait and caveat",
withTenantPrefix,
`use expiration
definition simple {
relation viewer: user with somecaveat and expiration
}`,
"",
[]SchemaDefinition{
namespace.Namespace("sometenant/simple",
namespace.MustRelation("viewer", nil,
namespace.AllowedRelationWithCaveatAndExpiration("sometenant/user", "...", namespace.AllowedCaveat("sometenant/somecaveat")),
),
),
},
},
}

for _, test := range tests {
Expand Down
85 changes: 65 additions & 20 deletions pkg/composableschemadsl/compiler/translator.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"container/list"
"fmt"
"path/filepath"
"slices"
"strings"

"github.com/ccoveille/go-safecast"
Expand All @@ -25,6 +26,8 @@ type translationContext struct {
mapper input.PositionMapper
schemaString string
skipValidate bool
allowedFlags []string
enabledFlags []string
existingNames *mapz.Set[string]
// The mapping of partial name -> relations represented by the partial
compiledPartials map[string][]*core.Relation
Expand All @@ -33,7 +36,7 @@ type translationContext struct {
unresolvedPartials *mapz.MultiMap[string, *dslNode]
}

func (tctx translationContext) prefixedPath(definitionName string) (string, error) {
func (tctx *translationContext) prefixedPath(definitionName string) (string, error) {
var prefix, name string
if err := stringz.SplitInto(definitionName, "/", &prefix, &name); err != nil {
if tctx.objectTypePrefix == nil {
Expand All @@ -52,7 +55,7 @@ func (tctx translationContext) prefixedPath(definitionName string) (string, erro

const Ellipsis = "..."

func translate(tctx translationContext, root *dslNode) (*CompiledSchema, error) {
func translate(tctx *translationContext, root *dslNode) (*CompiledSchema, error) {
orderedDefinitions := make([]SchemaDefinition, 0, len(root.GetChildren()))
var objectDefinitions []*core.NamespaceDefinition
var caveatDefinitions []*core.CaveatDefinition
Expand All @@ -71,6 +74,12 @@ func translate(tctx translationContext, root *dslNode) (*CompiledSchema, error)

for _, topLevelNode := range root.GetChildren() {
switch topLevelNode.GetType() {
case dslshape.NodeTypeUseFlag:
err := translateUseFlag(tctx, topLevelNode)
if err != nil {
return nil, err
}

case dslshape.NodeTypeCaveatDefinition:
log.Trace().Msg("adding caveat definition")
// TODO: Maybe refactor these in terms of a generic function?
Expand Down Expand Up @@ -115,7 +124,7 @@ func translate(tctx translationContext, root *dslNode) (*CompiledSchema, error)
}, nil
}

func translateCaveatDefinition(tctx translationContext, defNode *dslNode) (*core.CaveatDefinition, error) {
func translateCaveatDefinition(tctx *translationContext, defNode *dslNode) (*core.CaveatDefinition, error) {
definitionName, err := defNode.GetString(dslshape.NodeCaveatDefinitionPredicateName)
if err != nil {
return nil, defNode.WithSourceErrorf(definitionName, "invalid definition name: %w", err)
Expand Down Expand Up @@ -197,7 +206,7 @@ func translateCaveatDefinition(tctx translationContext, defNode *dslNode) (*core
return def, nil
}

func translateCaveatTypeReference(tctx translationContext, typeRefNode *dslNode) (*caveattypes.VariableType, error) {
func translateCaveatTypeReference(tctx *translationContext, typeRefNode *dslNode) (*caveattypes.VariableType, error) {
typeName, err := typeRefNode.GetString(dslshape.NodeCaveatTypeReferencePredicateType)
if err != nil {
return nil, typeRefNode.WithSourceErrorf(typeName, "invalid type name: %w", err)
Expand All @@ -221,7 +230,7 @@ func translateCaveatTypeReference(tctx translationContext, typeRefNode *dslNode)
return constructedType, nil
}

func translateObjectDefinition(tctx translationContext, defNode *dslNode) (*core.NamespaceDefinition, error) {
func translateObjectDefinition(tctx *translationContext, defNode *dslNode) (*core.NamespaceDefinition, error) {
definitionName, err := defNode.GetString(dslshape.NodeDefinitionPredicateName)
if err != nil {
return nil, defNode.WithSourceErrorf(definitionName, "invalid definition name: %w", err)
Expand Down Expand Up @@ -269,7 +278,7 @@ func translateObjectDefinition(tctx translationContext, defNode *dslNode) (*core
// A value of true treats that as an error state, since all partials should be resolved when translating definitions,
// where the false value returns the name of the partial for collection for future processing
// when translating partials.
func translateRelationsAndPermissions(tctx translationContext, astNode *dslNode, errorOnMissingReference bool) ([]*core.Relation, string, error) {
func translateRelationsAndPermissions(tctx *translationContext, astNode *dslNode, errorOnMissingReference bool) ([]*core.Relation, string, error) {
relationsAndPermissions := []*core.Relation{}
for _, definitionChildNode := range astNode.GetChildren() {
if definitionChildNode.GetType() == dslshape.NodeTypeComment {
Expand Down Expand Up @@ -345,7 +354,7 @@ func normalizeComment(value string) string {
return strings.Join(lines, "\n")
}

func translateRelationOrPermission(tctx translationContext, relOrPermNode *dslNode) (*core.Relation, error) {
func translateRelationOrPermission(tctx *translationContext, relOrPermNode *dslNode) (*core.Relation, error) {
switch relOrPermNode.GetType() {
case dslshape.NodeTypeRelation:
rel, err := translateRelation(tctx, relOrPermNode)
Expand All @@ -370,7 +379,7 @@ func translateRelationOrPermission(tctx translationContext, relOrPermNode *dslNo
}
}

func translateRelation(tctx translationContext, relationNode *dslNode) (*core.Relation, error) {
func translateRelation(tctx *translationContext, relationNode *dslNode) (*core.Relation, error) {
relationName, err := relationNode.GetString(dslshape.NodePredicateName)
if err != nil {
return nil, relationNode.Errorf("invalid relation name: %w", err)
Expand Down Expand Up @@ -400,7 +409,7 @@ func translateRelation(tctx translationContext, relationNode *dslNode) (*core.Re
return relation, nil
}

func translatePermission(tctx translationContext, permissionNode *dslNode) (*core.Relation, error) {
func translatePermission(tctx *translationContext, permissionNode *dslNode) (*core.Relation, error) {
permissionName, err := permissionNode.GetString(dslshape.NodePredicateName)
if err != nil {
return nil, permissionNode.Errorf("invalid permission name: %w", err)
Expand Down Expand Up @@ -430,7 +439,7 @@ func translatePermission(tctx translationContext, permissionNode *dslNode) (*cor
return permission, nil
}

func translateBinary(tctx translationContext, expressionNode *dslNode) (*core.SetOperation_Child, *core.SetOperation_Child, error) {
func translateBinary(tctx *translationContext, expressionNode *dslNode) (*core.SetOperation_Child, *core.SetOperation_Child, error) {
leftChild, err := expressionNode.Lookup(dslshape.NodeExpressionPredicateLeftExpr)
if err != nil {
return nil, nil, err
Expand All @@ -454,7 +463,7 @@ func translateBinary(tctx translationContext, expressionNode *dslNode) (*core.Se
return leftOperation, rightOperation, nil
}

func translateExpression(tctx translationContext, expressionNode *dslNode) (*core.UsersetRewrite, error) {
func translateExpression(tctx *translationContext, expressionNode *dslNode) (*core.UsersetRewrite, error) {
translated, err := translateExpressionDirect(tctx, expressionNode)
if err != nil {
return translated, err
Expand Down Expand Up @@ -482,7 +491,7 @@ func collapseOps(op *core.SetOperation_Child, handler func(rewrite *core.Userset
return collapsed
}

func translateExpressionDirect(tctx translationContext, expressionNode *dslNode) (*core.UsersetRewrite, error) {
func translateExpressionDirect(tctx *translationContext, expressionNode *dslNode) (*core.UsersetRewrite, error) {
// For union and intersection, we collapse a tree of binary operations into a flat list containing child
// operations of the *same* type.
translate := func(
Expand Down Expand Up @@ -528,7 +537,7 @@ func translateExpressionDirect(tctx translationContext, expressionNode *dslNode)
}
}

func translateExpressionOperation(tctx translationContext, expressionOpNode *dslNode) (*core.SetOperation_Child, error) {
func translateExpressionOperation(tctx *translationContext, expressionOpNode *dslNode) (*core.SetOperation_Child, error) {
translated, err := translateExpressionOperationDirect(tctx, expressionOpNode)
if err != nil {
return translated, err
Expand All @@ -538,7 +547,7 @@ func translateExpressionOperation(tctx translationContext, expressionOpNode *dsl
return translated, nil
}

func translateExpressionOperationDirect(tctx translationContext, expressionOpNode *dslNode) (*core.SetOperation_Child, error) {
func translateExpressionOperationDirect(tctx *translationContext, expressionOpNode *dslNode) (*core.SetOperation_Child, error) {
switch expressionOpNode.GetType() {
case dslshape.NodeTypeIdentifier:
referencedRelationName, err := expressionOpNode.GetString(dslshape.NodeIdentiferPredicateValue)
Expand Down Expand Up @@ -605,7 +614,7 @@ func translateExpressionOperationDirect(tctx translationContext, expressionOpNod
}
}

func translateAllowedRelations(tctx translationContext, typeRefNode *dslNode) ([]*core.AllowedRelation, error) {
func translateAllowedRelations(tctx *translationContext, typeRefNode *dslNode) ([]*core.AllowedRelation, error) {
switch typeRefNode.GetType() {
case dslshape.NodeTypeTypeReference:
references := []*core.AllowedRelation{}
Expand All @@ -631,7 +640,7 @@ func translateAllowedRelations(tctx translationContext, typeRefNode *dslNode) ([
}
}

func translateSpecificTypeReference(tctx translationContext, typeRefNode *dslNode) (*core.AllowedRelation, error) {
func translateSpecificTypeReference(tctx *translationContext, typeRefNode *dslNode) (*core.AllowedRelation, error) {
typePath, err := typeRefNode.GetString(dslshape.NodeSpecificReferencePredicateType)
if err != nil {
return nil, typeRefNode.Errorf("invalid type name: %w", err)
Expand Down Expand Up @@ -680,11 +689,34 @@ func translateSpecificTypeReference(tctx translationContext, typeRefNode *dslNod
},
}

// Add the caveat(s), if any.
err = addWithCaveats(tctx, typeRefNode, ref)
if err != nil {
return nil, typeRefNode.Errorf("invalid caveat: %w", err)
}

// Add the expiration trait, if any.
if traitNode, err := typeRefNode.Lookup(dslshape.NodeSpecificReferencePredicateTrait); err == nil {
traitName, err := traitNode.GetString(dslshape.NodeTraitPredicateTrait)
if err != nil {
return nil, typeRefNode.Errorf("invalid trait: %w", err)
}

if traitName != "expiration" {
return nil, typeRefNode.Errorf("invalid trait: %s", traitName)
}

if !slices.Contains(tctx.allowedFlags, "expiration") {
return nil, typeRefNode.Errorf("expiration trait is not allowed")
}

if !slices.Contains(tctx.enabledFlags, "expiration") {
return nil, typeRefNode.Errorf("expiration flag is not enabled; add `use expiration` to top of file")
}

ref.RequiredExpiration = &core.ExpirationTrait{}
}

if !tctx.skipValidate {
if err := ref.Validate(); err != nil {
return nil, typeRefNode.Errorf("invalid type relation: %w", err)
Expand All @@ -695,7 +727,7 @@ func translateSpecificTypeReference(tctx translationContext, typeRefNode *dslNod
return ref, nil
}

func addWithCaveats(tctx translationContext, typeRefNode *dslNode, ref *core.AllowedRelation) error {
func addWithCaveats(tctx *translationContext, typeRefNode *dslNode, ref *core.AllowedRelation) error {
caveats := typeRefNode.List(dslshape.NodeSpecificReferencePredicateCaveat)
if len(caveats) == 0 {
return nil
Expand Down Expand Up @@ -809,7 +841,7 @@ func translateImports(itctx importResolutionContext, root *dslNode) error {
return nil
}

func collectPartials(tctx translationContext, rootNode *dslNode) error {
func collectPartials(tctx *translationContext, rootNode *dslNode) error {
for _, topLevelNode := range rootNode.GetChildren() {
if topLevelNode.GetType() == dslshape.NodeTypePartial {
err := translatePartial(tctx, topLevelNode)
Expand All @@ -825,7 +857,7 @@ func collectPartials(tctx translationContext, rootNode *dslNode) error {
}

// This function modifies the translation context, so we don't need to return anything from it.
func translatePartial(tctx translationContext, partialNode *dslNode) error {
func translatePartial(tctx *translationContext, partialNode *dslNode) error {
partialName, err := partialNode.GetString(dslshape.NodePartialPredicateName)
if err != nil {
return err
Expand Down Expand Up @@ -863,7 +895,7 @@ func translatePartial(tctx translationContext, partialNode *dslNode) error {
// NOTE: we treat partial references in definitions and partials differently because a missing partial
// reference in definition compilation is an error state, where a missing partial reference in a
// partial definition is an indeterminate state.
func translatePartialReference(tctx translationContext, partialReferenceNode *dslNode, errorOnMissingReference bool) ([]*core.Relation, string, error) {
func translatePartialReference(tctx *translationContext, partialReferenceNode *dslNode, errorOnMissingReference bool) ([]*core.Relation, string, error) {
name, err := partialReferenceNode.GetString(dslshape.NodePartialReferencePredicateName)
if err != nil {
return nil, "", err
Expand All @@ -879,3 +911,16 @@ func translatePartialReference(tctx translationContext, partialReferenceNode *ds
}
return relationsAndPermissions, "", nil
}

// Translate use node and add flag to list of enabled flags
func translateUseFlag(tctx *translationContext, useFlagNode *dslNode) error {
flagName, err := useFlagNode.GetString(dslshape.NodeUseFlagPredicateName)
if err != nil {
return err
}
if slices.Contains(tctx.enabledFlags, flagName) {
return useFlagNode.Errorf("found duplicate use flag: %s", flagName)
}
tctx.enabledFlags = append(tctx.enabledFlags, flagName)
return nil
}
Loading

0 comments on commit 22e351c

Please sign in to comment.