From 79c24fd8e59daef239bc40a14c0a20a97e5d39d7 Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Thu, 30 Jan 2025 14:32:15 -0500 Subject: [PATCH] Change strict reading PG to only return rows when valid This is necessary for the error and retry logic to work in the strict read proxy --- .../postgres/postgres_shared_test.go | 65 +++++++++++++++++++ internal/datastore/postgres/strictreader.go | 28 ++++++-- .../proxy/checkingreplicated_test.go | 46 ++++++++----- internal/datastore/proxy/strictreplicated.go | 7 +- .../datastore/proxy/strictreplicated_test.go | 34 +++++++++- 5 files changed, 153 insertions(+), 27 deletions(-) diff --git a/internal/datastore/postgres/postgres_shared_test.go b/internal/datastore/postgres/postgres_shared_test.go index 489495233d..e6561b6cd1 100644 --- a/internal/datastore/postgres/postgres_shared_test.go +++ b/internal/datastore/postgres/postgres_shared_test.go @@ -25,6 +25,7 @@ import ( "github.com/authzed/spicedb/internal/datastore/common" pgcommon "github.com/authzed/spicedb/internal/datastore/postgres/common" pgversion "github.com/authzed/spicedb/internal/datastore/postgres/version" + "github.com/authzed/spicedb/internal/datastore/proxy" "github.com/authzed/spicedb/internal/testfixtures" testdatastore "github.com/authzed/spicedb/internal/testserver/datastore" "github.com/authzed/spicedb/pkg/datastore" @@ -240,6 +241,16 @@ func testPostgresDatastore(t *testing.T, config postgresTestConfig) { MigrationPhase(config.migrationPhase), )) + t.Run("TestStrictReadModeFallback", createReplicaDatastoreTest( + b, + StrictReadModeFallbackTest, + RevisionQuantization(0), + GCWindow(1000*time.Second), + GCInterval(veryLargeGCInterval), + WatchBufferLength(50), + MigrationPhase(config.migrationPhase), + )) + t.Run("TestLocking", createMultiDatastoreTest( b, LockingTest, @@ -1568,6 +1579,60 @@ func LockingTest(t *testing.T, ds datastore.Datastore, ds2 datastore.Datastore) require.NoError(t, err) } +func StrictReadModeFallbackTest(t *testing.T, primaryDS datastore.Datastore, unwrappedReplicaDS datastore.Datastore) { + require := require.New(t) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Write some relationships. + _, err := primaryDS.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + rtu := tuple.Touch(tuple.MustParse("resource:123#reader@user:456")) + return rwt.WriteRelationships(ctx, []tuple.RelationshipUpdate{rtu}) + }) + require.NoError(err) + + // Get the HEAD revision. + lowestRevision, err := primaryDS.HeadRevision(ctx) + require.NoError(err) + + // Wrap the replica DS. + replicaDS, err := proxy.NewStrictReplicatedDatastore(primaryDS, unwrappedReplicaDS.(datastore.StrictReadDatastore)) + require.NoError(err) + + // Perform a read at the head revision, which should succeed. + reader := replicaDS.SnapshotReader(lowestRevision) + it, err := reader.QueryRelationships(ctx, datastore.RelationshipsFilter{ + OptionalResourceType: "resource", + }) + require.NoError(err) + + found, err := datastore.IteratorToSlice(it) + require.NoError(err) + require.NotEmpty(found) + + // Perform a read at a manually constructed revision beyond head, which should fallback to the primary. + badRev := postgresRevision{ + snapshot: pgSnapshot{ + // NOTE: the struct defines this value as uint64, but the underlying + // revision is defined as an int64, so we run into an overflow issue + // if we try and use a big uint64. + xmin: 123456789, + xmax: 123456789, + }, + } + + limit := uint64(50) + it, err = replicaDS.SnapshotReader(badRev).QueryRelationships(ctx, datastore.RelationshipsFilter{ + OptionalResourceType: "resource", + }, options.WithLimit(&limit)) + require.NoError(err) + + found2, err := datastore.IteratorToSlice(it) + require.NoError(err) + require.Equal(len(found), len(found2)) +} + func StrictReadModeTest(t *testing.T, primaryDS datastore.Datastore, replicaDS datastore.Datastore) { require := require.New(t) diff --git a/internal/datastore/postgres/strictreader.go b/internal/datastore/postgres/strictreader.go index 15bb52ab39..f7450d0efa 100644 --- a/internal/datastore/postgres/strictreader.go +++ b/internal/datastore/postgres/strictreader.go @@ -11,6 +11,7 @@ import ( "github.com/authzed/spicedb/internal/datastore/common" pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common" + "github.com/authzed/spicedb/pkg/spiceerrors" ) const pgInvalidArgument = "22023" @@ -26,15 +27,15 @@ type strictReaderQueryFuncs struct { func (srqf strictReaderQueryFuncs) ExecFunc(ctx context.Context, tagFunc func(ctx context.Context, tag pgconn.CommandTag, err error) error, sql string, args ...any) error { // NOTE: it is *required* for the pgx.QueryExecModeSimpleProtocol to be added as pgx will otherwise wrap // the query as a prepared statement, which does *not* support running more than a single statement at a time. - return srqf.rewriteError(srqf.wrapped.ExecFunc(ctx, tagFunc, srqf.addAssertToSQL(sql), append([]interface{}{pgx.QueryExecModeSimpleProtocol}, args...)...)) + return srqf.rewriteError(srqf.wrapped.ExecFunc(ctx, tagFunc, srqf.addAssertToSelectSQL(sql), append([]interface{}{pgx.QueryExecModeSimpleProtocol}, args...)...)) } func (srqf strictReaderQueryFuncs) QueryFunc(ctx context.Context, rowsFunc func(ctx context.Context, rows pgx.Rows) error, sql string, args ...any) error { - return srqf.rewriteError(srqf.wrapped.QueryFunc(ctx, rowsFunc, srqf.addAssertToSQL(sql), append([]interface{}{pgx.QueryExecModeSimpleProtocol}, args...)...)) + return srqf.rewriteError(srqf.wrapped.QueryFunc(ctx, rowsFunc, srqf.addAssertToSelectSQL(sql), append([]interface{}{pgx.QueryExecModeSimpleProtocol}, args...)...)) } func (srqf strictReaderQueryFuncs) QueryRowFunc(ctx context.Context, rowFunc func(ctx context.Context, row pgx.Row) error, sql string, args ...any) error { - return srqf.rewriteError(srqf.wrapped.QueryRowFunc(ctx, rowFunc, srqf.addAssertToSQL(sql), append([]interface{}{pgx.QueryExecModeSimpleProtocol}, args...)...)) + return srqf.rewriteError(srqf.wrapped.QueryRowFunc(ctx, rowFunc, srqf.addAssertToSelectSQL(sql), append([]interface{}{pgx.QueryExecModeSimpleProtocol}, args...)...)) } func (srqf strictReaderQueryFuncs) rewriteError(err error) error { @@ -53,13 +54,28 @@ func (srqf strictReaderQueryFuncs) rewriteError(err error) error { return err } -func (srqf strictReaderQueryFuncs) addAssertToSQL(sql string) string { +func (srqf strictReaderQueryFuncs) addAssertToSelectSQL(sql string) string { + spiceerrors.DebugAssert(func() bool { + return strings.HasPrefix(sql, "SELECT ") + }, "strictReaderQueryFuncs can only wrap SELECT queries") + // The assertion checks that the transaction is not reading from the future or from a // transaction that is still in-progress on the replica. If the transaction is not yet // available on the replica at all, the call to `pg_xact_status` will fail with an invalid // argument error and a message indicating that the xid "is in the future". If the transaction // does exist, but has not yet been committed (or aborted), the call to `pg_xact_status` will return // "in progress". rewriteError will catch these cases and return a RevisionUnavailableError. - assertion := fmt.Sprintf(`; do $$ begin assert (select pg_xact_status(%d::text::xid8) != 'in progress'), 'replica missing revision';end;$$`, srqf.revision.snapshot.xmin-1) - return sql + assertion + // + // We run the query *first* (but filtered) as PGX will not be able to read rows if the assertion + // is run first. However, we do not want to return any rows if the assertion will fail, so we add it + // as a filter to the select as well. + wrapped := fmt.Sprintf(` + SELECT * FROM (%s) AS results WHERE pg_xact_status(%d::text::xid8) != 'in progress'; + DO $$ + BEGIN + ASSERT (select pg_xact_status(%d::text::xid8) != 'in progress'), 'replica missing revision'; + END + $$; + `, sql, srqf.revision.snapshot.xmin-1, srqf.revision.snapshot.xmin-1) + return wrapped } diff --git a/internal/datastore/proxy/checkingreplicated_test.go b/internal/datastore/proxy/checkingreplicated_test.go index a475f63ac9..fdf5bd528c 100644 --- a/internal/datastore/proxy/checkingreplicated_test.go +++ b/internal/datastore/proxy/checkingreplicated_test.go @@ -16,8 +16,8 @@ import ( ) func TestCheckingReplicatedReaderFallsbackToPrimaryOnCheckRevisionFailure(t *testing.T) { - primary := fakeDatastore{true, revisionparsing.MustParseRevisionForTest("2")} - replica := fakeDatastore{false, revisionparsing.MustParseRevisionForTest("1")} + primary := fakeDatastore{"primary", revisionparsing.MustParseRevisionForTest("2")} + replica := fakeDatastore{"replica", revisionparsing.MustParseRevisionForTest("1")} replicated, err := NewCheckingReplicatedDatastore(primary, replica) require.NoError(t, err) @@ -40,8 +40,8 @@ func TestCheckingReplicatedReaderFallsbackToPrimaryOnCheckRevisionFailure(t *tes } func TestCheckingReplicatedReaderFallsbackToPrimaryOnRevisionNotAvailableError(t *testing.T) { - primary := fakeDatastore{true, revisionparsing.MustParseRevisionForTest("2")} - replica := fakeDatastore{false, revisionparsing.MustParseRevisionForTest("1")} + primary := fakeDatastore{"primary", revisionparsing.MustParseRevisionForTest("2")} + replica := fakeDatastore{"replica", revisionparsing.MustParseRevisionForTest("1")} replicated, err := NewCheckingReplicatedDatastore(primary, replica) require.NoError(t, err) @@ -55,8 +55,8 @@ func TestCheckingReplicatedReaderFallsbackToPrimaryOnRevisionNotAvailableError(t func TestReplicatedReaderReturnsExpectedError(t *testing.T) { for _, requireCheck := range []bool{true, false} { t.Run(fmt.Sprintf("requireCheck=%v", requireCheck), func(t *testing.T) { - primary := fakeDatastore{true, revisionparsing.MustParseRevisionForTest("2")} - replica := fakeDatastore{false, revisionparsing.MustParseRevisionForTest("1")} + primary := fakeDatastore{"primary", revisionparsing.MustParseRevisionForTest("2")} + replica := fakeDatastore{"replica", revisionparsing.MustParseRevisionForTest("1")} var ds datastore.Datastore if requireCheck { @@ -79,14 +79,14 @@ func TestReplicatedReaderReturnsExpectedError(t *testing.T) { } type fakeDatastore struct { - isPrimary bool - revision datastore.Revision + state string + revision datastore.Revision } func (f fakeDatastore) SnapshotReader(revision datastore.Revision) datastore.Reader { return fakeSnapshotReader{ - revision: revision, - isPrimary: f.isPrimary, + revision: revision, + state: f.state, } } @@ -143,12 +143,12 @@ func (f fakeDatastore) IsStrictReadModeEnabled() bool { } type fakeSnapshotReader struct { - revision datastore.Revision - isPrimary bool + revision datastore.Revision + state string } func (fsr fakeSnapshotReader) LookupNamespacesWithNames(_ context.Context, nsNames []string) ([]datastore.RevisionedDefinition[*corev1.NamespaceDefinition], error) { - if fsr.isPrimary { + if fsr.state == "primary" { return []datastore.RevisionedDefinition[*corev1.NamespaceDefinition]{ { Definition: &corev1.NamespaceDefinition{ @@ -159,7 +159,7 @@ func (fsr fakeSnapshotReader) LookupNamespacesWithNames(_ context.Context, nsNam }, nil } - if !fsr.isPrimary && fsr.revision.GreaterThan(revisionparsing.MustParseRevisionForTest("2")) { + if fsr.revision.GreaterThan(revisionparsing.MustParseRevisionForTest("2")) { return nil, common.NewRevisionUnavailableError(fmt.Errorf("revision not available")) } @@ -208,7 +208,7 @@ func (fakeSnapshotReader) LookupCounters(ctx context.Context) ([]datastore.Relat func fakeIterator(fsr fakeSnapshotReader) datastore.RelationshipIterator { return func(yield func(tuple.Relationship, error) bool) { - if fsr.isPrimary { + if fsr.state == "primary" { if !yield(tuple.MustParse("resource:123#viewer@user:tom"), nil) { return } @@ -218,6 +218,22 @@ func fakeIterator(fsr fakeSnapshotReader) datastore.RelationshipIterator { return } + if fsr.state == "replica-with-normal-error" { + if !yield(tuple.MustParse("resource:123#viewer@user:tom"), nil) { + return + } + if !yield(tuple.MustParse("resource:456#viewer@user:tom"), nil) { + return + } + if !yield(tuple.Relationship{}, fmt.Errorf("raising an expected error")) { + return + } + if !yield(tuple.MustParse("resource:789#viewer@user:tom"), nil) { + return + } + return + } + if fsr.revision.GreaterThan(revisionparsing.MustParseRevisionForTest("2")) { yield(tuple.Relationship{}, common.NewRevisionUnavailableError(fmt.Errorf("revision not available"))) return diff --git a/internal/datastore/proxy/strictreplicated.go b/internal/datastore/proxy/strictreplicated.go index d52a1b8fa5..d84357e647 100644 --- a/internal/datastore/proxy/strictreplicated.go +++ b/internal/datastore/proxy/strictreplicated.go @@ -134,7 +134,7 @@ func queryRelationships[F any, O any]( return nil, err } - isFirstResult := true + beforeResultsYielded := true requiresFallback := false return func(yield func(tuple.Relationship, error) bool) { replicaLoop: @@ -143,7 +143,7 @@ func queryRelationships[F any, O any]( // If the RevisionUnavailableError is returned on the first result, we should fallback // to the primary. if errors.As(err, &common.RevisionUnavailableError{}) { - if !isFirstResult { + if !beforeResultsYielded { yield(tuple.Relationship{}, spiceerrors.MustBugf("RevisionUnavailableError should only be returned on the first result")) return } @@ -154,9 +154,10 @@ func queryRelationships[F any, O any]( if !yield(tuple.Relationship{}, err) { return } + continue } - isFirstResult = false + beforeResultsYielded = false if !yield(result, nil) { return } diff --git a/internal/datastore/proxy/strictreplicated_test.go b/internal/datastore/proxy/strictreplicated_test.go index 663bebd3c4..0994edf864 100644 --- a/internal/datastore/proxy/strictreplicated_test.go +++ b/internal/datastore/proxy/strictreplicated_test.go @@ -11,7 +11,7 @@ import ( ) func TestStrictReplicatedReaderWithOnlyPrimary(t *testing.T) { - primary := fakeDatastore{true, revisionparsing.MustParseRevisionForTest("2")} + primary := fakeDatastore{"primary", revisionparsing.MustParseRevisionForTest("2")} replicated, err := NewStrictReplicatedDatastore(primary) require.NoError(t, err) @@ -20,8 +20,8 @@ func TestStrictReplicatedReaderWithOnlyPrimary(t *testing.T) { } func TestStrictReplicatedQueryFallsbackToPrimaryOnRevisionNotAvailableError(t *testing.T) { - primary := fakeDatastore{true, revisionparsing.MustParseRevisionForTest("2")} - replica := fakeDatastore{false, revisionparsing.MustParseRevisionForTest("1")} + primary := fakeDatastore{"primary", revisionparsing.MustParseRevisionForTest("2")} + replica := fakeDatastore{"replica", revisionparsing.MustParseRevisionForTest("1")} replicated, err := NewStrictReplicatedDatastore(primary, replica) require.NoError(t, err) @@ -87,3 +87,31 @@ func TestStrictReplicatedQueryFallsbackToPrimaryOnRevisionNotAvailableError(t *t require.NoError(t, err) require.Equal(t, 2, len(revfound)) } + +func TestStrictReplicatedQueryNonFallbackError(t *testing.T) { + primary := fakeDatastore{"primary", revisionparsing.MustParseRevisionForTest("2")} + replica := fakeDatastore{"replica-with-normal-error", revisionparsing.MustParseRevisionForTest("1")} + + replicated, err := NewStrictReplicatedDatastore(primary, replica) + require.NoError(t, err) + + // Query the replicated, which should return the error. + reader := replicated.SnapshotReader(revisionparsing.MustParseRevisionForTest("3")) + iter, err := reader.QueryRelationships(context.Background(), datastore.RelationshipsFilter{ + OptionalResourceType: "resource", + }) + require.NoError(t, err) + + relsCollected := 0 + var errFound error + for _, err := range iter { + if err != nil { + errFound = err + } else { + relsCollected++ + } + } + + require.Equal(t, 3, relsCollected) + require.ErrorContains(t, errFound, "raising an expected error") +}