Skip to content

Commit

Permalink
Merge pull request #2219 from josephschorr/strict-read-first
Browse files Browse the repository at this point in the history
Change strict reading PG to only return rows when valid
  • Loading branch information
josephschorr authored Jan 31, 2025
2 parents 542053f + 79c24fd commit 6db2a39
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 27 deletions.
65 changes: 65 additions & 0 deletions internal/datastore/postgres/postgres_shared_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/authzed/spicedb/internal/datastore/common"
pgcommon "github.com/authzed/spicedb/internal/datastore/postgres/common"
pgversion "github.com/authzed/spicedb/internal/datastore/postgres/version"
"github.com/authzed/spicedb/internal/datastore/proxy"
"github.com/authzed/spicedb/internal/testfixtures"
testdatastore "github.com/authzed/spicedb/internal/testserver/datastore"
"github.com/authzed/spicedb/pkg/datastore"
Expand Down Expand Up @@ -240,6 +241,16 @@ func testPostgresDatastore(t *testing.T, config postgresTestConfig) {
MigrationPhase(config.migrationPhase),
))

t.Run("TestStrictReadModeFallback", createReplicaDatastoreTest(
b,
StrictReadModeFallbackTest,
RevisionQuantization(0),
GCWindow(1000*time.Second),
GCInterval(veryLargeGCInterval),
WatchBufferLength(50),
MigrationPhase(config.migrationPhase),
))

t.Run("TestLocking", createMultiDatastoreTest(
b,
LockingTest,
Expand Down Expand Up @@ -1568,6 +1579,60 @@ func LockingTest(t *testing.T, ds datastore.Datastore, ds2 datastore.Datastore)
require.NoError(t, err)
}

// StrictReadModeFallbackTest writes a relationship to the primary, wraps the
// given replica in a strict replicated datastore, and checks two reads through
// the wrapper: one at the primary's HEAD revision, and one at a hand-built
// revision that is expected to be unavailable on the replica. Both reads must
// succeed and return the same number of relationships.
func StrictReadModeFallbackTest(t *testing.T, primaryDS datastore.Datastore, unwrappedReplicaDS datastore.Datastore) {
	require := require.New(t)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Write some relationships.
	_, err := primaryDS.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error {
		rtu := tuple.Touch(tuple.MustParse("resource:123#reader@user:456"))
		return rwt.WriteRelationships(ctx, []tuple.RelationshipUpdate{rtu})
	})
	require.NoError(err)

	// Get the HEAD revision.
	headRevision, err := primaryDS.HeadRevision(ctx)
	require.NoError(err)

	// Wrap the replica DS. The assertion is checked so that a misconfigured
	// test harness produces a readable failure instead of a panic.
	strictReplicaDS, ok := unwrappedReplicaDS.(datastore.StrictReadDatastore)
	require.True(ok, "replica datastore must implement datastore.StrictReadDatastore")

	replicaDS, err := proxy.NewStrictReplicatedDatastore(primaryDS, strictReplicaDS)
	require.NoError(err)

	// Perform a read at the head revision, which should succeed.
	reader := replicaDS.SnapshotReader(headRevision)
	it, err := reader.QueryRelationships(ctx, datastore.RelationshipsFilter{
		OptionalResourceType: "resource",
	})
	require.NoError(err)

	found, err := datastore.IteratorToSlice(it)
	require.NoError(err)
	require.NotEmpty(found)

	// Perform a read at a manually constructed revision beyond head, which
	// should fall back to the primary.
	badRev := postgresRevision{
		snapshot: pgSnapshot{
			// NOTE: the struct defines this value as uint64, but the underlying
			// revision is defined as an int64, so we run into an overflow issue
			// if we try and use a big uint64.
			xmin: 123456789,
			xmax: 123456789,
		},
	}

	limit := uint64(50)
	it, err = replicaDS.SnapshotReader(badRev).QueryRelationships(ctx, datastore.RelationshipsFilter{
		OptionalResourceType: "resource",
	}, options.WithLimit(&limit))
	require.NoError(err)

	found2, err := datastore.IteratorToSlice(it)
	require.NoError(err)
	require.Equal(len(found), len(found2))
}

func StrictReadModeTest(t *testing.T, primaryDS datastore.Datastore, replicaDS datastore.Datastore) {
require := require.New(t)

Expand Down
28 changes: 22 additions & 6 deletions internal/datastore/postgres/strictreader.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/authzed/spicedb/internal/datastore/common"
pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common"
"github.com/authzed/spicedb/pkg/spiceerrors"
)

const pgInvalidArgument = "22023"
Expand All @@ -26,15 +27,15 @@ type strictReaderQueryFuncs struct {
func (srqf strictReaderQueryFuncs) ExecFunc(ctx context.Context, tagFunc func(ctx context.Context, tag pgconn.CommandTag, err error) error, sql string, args ...any) error {
// NOTE: it is *required* for the pgx.QueryExecModeSimpleProtocol to be added as pgx will otherwise wrap
// the query as a prepared statement, which does *not* support running more than a single statement at a time.
return srqf.rewriteError(srqf.wrapped.ExecFunc(ctx, tagFunc, srqf.addAssertToSQL(sql), append([]interface{}{pgx.QueryExecModeSimpleProtocol}, args...)...))
return srqf.rewriteError(srqf.wrapped.ExecFunc(ctx, tagFunc, srqf.addAssertToSelectSQL(sql), append([]interface{}{pgx.QueryExecModeSimpleProtocol}, args...)...))
}

func (srqf strictReaderQueryFuncs) QueryFunc(ctx context.Context, rowsFunc func(ctx context.Context, rows pgx.Rows) error, sql string, args ...any) error {
return srqf.rewriteError(srqf.wrapped.QueryFunc(ctx, rowsFunc, srqf.addAssertToSQL(sql), append([]interface{}{pgx.QueryExecModeSimpleProtocol}, args...)...))
return srqf.rewriteError(srqf.wrapped.QueryFunc(ctx, rowsFunc, srqf.addAssertToSelectSQL(sql), append([]interface{}{pgx.QueryExecModeSimpleProtocol}, args...)...))
}

func (srqf strictReaderQueryFuncs) QueryRowFunc(ctx context.Context, rowFunc func(ctx context.Context, row pgx.Row) error, sql string, args ...any) error {
return srqf.rewriteError(srqf.wrapped.QueryRowFunc(ctx, rowFunc, srqf.addAssertToSQL(sql), append([]interface{}{pgx.QueryExecModeSimpleProtocol}, args...)...))
return srqf.rewriteError(srqf.wrapped.QueryRowFunc(ctx, rowFunc, srqf.addAssertToSelectSQL(sql), append([]interface{}{pgx.QueryExecModeSimpleProtocol}, args...)...))
}

func (srqf strictReaderQueryFuncs) rewriteError(err error) error {
Expand All @@ -53,13 +54,28 @@ func (srqf strictReaderQueryFuncs) rewriteError(err error) error {
return err
}

func (srqf strictReaderQueryFuncs) addAssertToSQL(sql string) string {
func (srqf strictReaderQueryFuncs) addAssertToSelectSQL(sql string) string {
spiceerrors.DebugAssert(func() bool {
return strings.HasPrefix(sql, "SELECT ")
}, "strictReaderQueryFuncs can only wrap SELECT queries")

// The assertion checks that the transaction is not reading from the future or from a
// transaction that is still in-progress on the replica. If the transaction is not yet
// available on the replica at all, the call to `pg_xact_status` will fail with an invalid
// argument error and a message indicating that the xid "is in the future". If the transaction
// does exist, but has not yet been committed (or aborted), the call to `pg_xact_status` will return
// "in progress". rewriteError will catch these cases and return a RevisionUnavailableError.
assertion := fmt.Sprintf(`; do $$ begin assert (select pg_xact_status(%d::text::xid8) != 'in progress'), 'replica missing revision';end;$$`, srqf.revision.snapshot.xmin-1)
return sql + assertion
//
// We run the query *first* (but filtered) as PGX will not be able to read rows if the assertion
// is run first. However, we do not want to return any rows if the assertion will fail, so we add it
// as a filter to the select as well.
wrapped := fmt.Sprintf(`
SELECT * FROM (%s) AS results WHERE pg_xact_status(%d::text::xid8) != 'in progress';
DO $$
BEGIN
ASSERT (select pg_xact_status(%d::text::xid8) != 'in progress'), 'replica missing revision';
END
$$;
`, sql, srqf.revision.snapshot.xmin-1, srqf.revision.snapshot.xmin-1)
return wrapped
}
46 changes: 31 additions & 15 deletions internal/datastore/proxy/checkingreplicated_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ import (
)

func TestCheckingReplicatedReaderFallsbackToPrimaryOnCheckRevisionFailure(t *testing.T) {
primary := fakeDatastore{true, revisionparsing.MustParseRevisionForTest("2")}
replica := fakeDatastore{false, revisionparsing.MustParseRevisionForTest("1")}
primary := fakeDatastore{"primary", revisionparsing.MustParseRevisionForTest("2")}
replica := fakeDatastore{"replica", revisionparsing.MustParseRevisionForTest("1")}

replicated, err := NewCheckingReplicatedDatastore(primary, replica)
require.NoError(t, err)
Expand All @@ -40,8 +40,8 @@ func TestCheckingReplicatedReaderFallsbackToPrimaryOnCheckRevisionFailure(t *tes
}

func TestCheckingReplicatedReaderFallsbackToPrimaryOnRevisionNotAvailableError(t *testing.T) {
primary := fakeDatastore{true, revisionparsing.MustParseRevisionForTest("2")}
replica := fakeDatastore{false, revisionparsing.MustParseRevisionForTest("1")}
primary := fakeDatastore{"primary", revisionparsing.MustParseRevisionForTest("2")}
replica := fakeDatastore{"replica", revisionparsing.MustParseRevisionForTest("1")}

replicated, err := NewCheckingReplicatedDatastore(primary, replica)
require.NoError(t, err)
Expand All @@ -55,8 +55,8 @@ func TestCheckingReplicatedReaderFallsbackToPrimaryOnRevisionNotAvailableError(t
func TestReplicatedReaderReturnsExpectedError(t *testing.T) {
for _, requireCheck := range []bool{true, false} {
t.Run(fmt.Sprintf("requireCheck=%v", requireCheck), func(t *testing.T) {
primary := fakeDatastore{true, revisionparsing.MustParseRevisionForTest("2")}
replica := fakeDatastore{false, revisionparsing.MustParseRevisionForTest("1")}
primary := fakeDatastore{"primary", revisionparsing.MustParseRevisionForTest("2")}
replica := fakeDatastore{"replica", revisionparsing.MustParseRevisionForTest("1")}

var ds datastore.Datastore
if requireCheck {
Expand All @@ -79,14 +79,14 @@ func TestReplicatedReaderReturnsExpectedError(t *testing.T) {
}

type fakeDatastore struct {
isPrimary bool
revision datastore.Revision
state string
revision datastore.Revision
}

func (f fakeDatastore) SnapshotReader(revision datastore.Revision) datastore.Reader {
return fakeSnapshotReader{
revision: revision,
isPrimary: f.isPrimary,
revision: revision,
state: f.state,
}
}

Expand Down Expand Up @@ -143,12 +143,12 @@ func (f fakeDatastore) IsStrictReadModeEnabled() bool {
}

type fakeSnapshotReader struct {
revision datastore.Revision
isPrimary bool
revision datastore.Revision
state string
}

func (fsr fakeSnapshotReader) LookupNamespacesWithNames(_ context.Context, nsNames []string) ([]datastore.RevisionedDefinition[*corev1.NamespaceDefinition], error) {
if fsr.isPrimary {
if fsr.state == "primary" {
return []datastore.RevisionedDefinition[*corev1.NamespaceDefinition]{
{
Definition: &corev1.NamespaceDefinition{
Expand All @@ -159,7 +159,7 @@ func (fsr fakeSnapshotReader) LookupNamespacesWithNames(_ context.Context, nsNam
}, nil
}

if !fsr.isPrimary && fsr.revision.GreaterThan(revisionparsing.MustParseRevisionForTest("2")) {
if fsr.revision.GreaterThan(revisionparsing.MustParseRevisionForTest("2")) {
return nil, common.NewRevisionUnavailableError(fmt.Errorf("revision not available"))
}

Expand Down Expand Up @@ -208,7 +208,7 @@ func (fakeSnapshotReader) LookupCounters(ctx context.Context) ([]datastore.Relat

func fakeIterator(fsr fakeSnapshotReader) datastore.RelationshipIterator {
return func(yield func(tuple.Relationship, error) bool) {
if fsr.isPrimary {
if fsr.state == "primary" {
if !yield(tuple.MustParse("resource:123#viewer@user:tom"), nil) {
return
}
Expand All @@ -218,6 +218,22 @@ func fakeIterator(fsr fakeSnapshotReader) datastore.RelationshipIterator {
return
}

if fsr.state == "replica-with-normal-error" {
if !yield(tuple.MustParse("resource:123#viewer@user:tom"), nil) {
return
}
if !yield(tuple.MustParse("resource:456#viewer@user:tom"), nil) {
return
}
if !yield(tuple.Relationship{}, fmt.Errorf("raising an expected error")) {
return
}
if !yield(tuple.MustParse("resource:789#viewer@user:tom"), nil) {
return
}
return
}

if fsr.revision.GreaterThan(revisionparsing.MustParseRevisionForTest("2")) {
yield(tuple.Relationship{}, common.NewRevisionUnavailableError(fmt.Errorf("revision not available")))
return
Expand Down
7 changes: 4 additions & 3 deletions internal/datastore/proxy/strictreplicated.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func queryRelationships[F any, O any](
return nil, err
}

isFirstResult := true
beforeResultsYielded := true
requiresFallback := false
return func(yield func(tuple.Relationship, error) bool) {
replicaLoop:
Expand All @@ -143,7 +143,7 @@ func queryRelationships[F any, O any](
// If the RevisionUnavailableError is returned on the first result, we should fall back
// to the primary.
if errors.As(err, &common.RevisionUnavailableError{}) {
if !isFirstResult {
if !beforeResultsYielded {
yield(tuple.Relationship{}, spiceerrors.MustBugf("RevisionUnavailableError should only be returned on the first result"))
return
}
Expand All @@ -154,9 +154,10 @@ func queryRelationships[F any, O any](
if !yield(tuple.Relationship{}, err) {
return
}
continue
}

isFirstResult = false
beforeResultsYielded = false
if !yield(result, nil) {
return
}
Expand Down
34 changes: 31 additions & 3 deletions internal/datastore/proxy/strictreplicated_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
)

func TestStrictReplicatedReaderWithOnlyPrimary(t *testing.T) {
primary := fakeDatastore{true, revisionparsing.MustParseRevisionForTest("2")}
primary := fakeDatastore{"primary", revisionparsing.MustParseRevisionForTest("2")}

replicated, err := NewStrictReplicatedDatastore(primary)
require.NoError(t, err)
Expand All @@ -20,8 +20,8 @@ func TestStrictReplicatedReaderWithOnlyPrimary(t *testing.T) {
}

func TestStrictReplicatedQueryFallsbackToPrimaryOnRevisionNotAvailableError(t *testing.T) {
primary := fakeDatastore{true, revisionparsing.MustParseRevisionForTest("2")}
replica := fakeDatastore{false, revisionparsing.MustParseRevisionForTest("1")}
primary := fakeDatastore{"primary", revisionparsing.MustParseRevisionForTest("2")}
replica := fakeDatastore{"replica", revisionparsing.MustParseRevisionForTest("1")}

replicated, err := NewStrictReplicatedDatastore(primary, replica)
require.NoError(t, err)
Expand Down Expand Up @@ -87,3 +87,31 @@ func TestStrictReplicatedQueryFallsbackToPrimaryOnRevisionNotAvailableError(t *t
require.NoError(t, err)
require.Equal(t, 2, len(revfound))
}

// TestStrictReplicatedQueryNonFallbackError checks that an ordinary error
// raised mid-iteration by the replica is passed through to the caller, and
// that iteration carries on afterwards: all three relationships are still
// collected alongside the error.
func TestStrictReplicatedQueryNonFallbackError(t *testing.T) {
	primary := fakeDatastore{"primary", revisionparsing.MustParseRevisionForTest("2")}
	replica := fakeDatastore{"replica-with-normal-error", revisionparsing.MustParseRevisionForTest("1")}

	replicated, err := NewStrictReplicatedDatastore(primary, replica)
	require.NoError(t, err)

	// Query through the replicated datastore; the iterator itself is created
	// without error, and the failure only surfaces while iterating.
	queryRev := revisionparsing.MustParseRevisionForTest("3")
	filter := datastore.RelationshipsFilter{OptionalResourceType: "resource"}
	iter, err := replicated.SnapshotReader(queryRev).QueryRelationships(context.Background(), filter)
	require.NoError(t, err)

	var (
		relCount int
		lastErr  error
	)
	for _, iterErr := range iter {
		if iterErr == nil {
			relCount++
			continue
		}
		lastErr = iterErr
	}

	require.Equal(t, 3, relCount)
	require.ErrorContains(t, lastErr, "raising an expected error")
}

0 comments on commit 6db2a39

Please sign in to comment.