Skip to content

Commit 85d31f3

Browse files
committed
pgcdc: fix table metadata
1 parent 7aecd20 commit 85d31f3

File tree

5 files changed

+173
-10
lines changed

5 files changed

+173
-10
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ All notable changes to this project will be documented in this file.
1717
### Fixed
1818

1919
- Fix a snapshot stream consistency issue with `postgres_cdc` where data could be missed if writes were happening during the snapshot phase. (@rockwotj)
20+
- Fix an issue where `@table` metadata was quoted for the snapshot phase in `postgres_cdc`. (@rockwotj)
2021

2122
### Changed
2223

internal/impl/postgresql/integration_test.go

+116
Original file line numberDiff line numberDiff line change
@@ -924,3 +924,119 @@ read_until:
924924
require.Equal(t, expected, sequenceNumbers)
925925
batchMu.Unlock()
926926
}
927+
928+
// TestIntegrationPostgresMetadata runs postgres_cdc against a quoted
// mixed-case table and a plain lowercase table, and asserts that the emitted
// metadata carries the unquoted table name for both the snapshot ("read")
// and the streamed ("insert") messages.
func TestIntegrationPostgresMetadata(t *testing.T) {
929+
t.Parallel()
930+
integration.CheckSkip(t)
931+
pool, err := dockertest.NewPool("")
932+
require.NoError(t, err)
933+
934+
var (
935+
resource *dockertest.Resource
936+
db *sql.DB
937+
)
938+
939+
resource, db, err = ResourceWithPostgreSQLVersion(t, pool, "16")
940+
require.NoError(t, err)
941+
require.NoError(t, resource.Expire(120))
942+
943+
hostAndPort := resource.GetHostPort("5432/tcp")
944+
hostAndPortSplited := strings.Split(hostAndPort, ":")
945+
password := "l]YLSc|4[i56%{gY"
946+
947+
// NOTE(review): err has not been reassigned since the check after
// ResourceWithPostgreSQLVersion, so this assertion looks redundant.
require.NoError(t, err)
948+
949+
// Seed one row per table before the stream starts.
_, err = db.Exec(`INSERT INTO "FlightsCompositePK" ("Seq", "Name", "CreatedAt") VALUES ($1, $2, $3);`, 1, "delta", "2006-01-02T15:04:05Z07:00")
950+
require.NoError(t, err)
951+
_, err = db.Exec(`INSERT INTO flights (name, created_at) VALUES ($1, $2);`, "delta", "2006-01-02T15:04:05Z07:00")
952+
require.NoError(t, err)
953+
954+
databaseURL := fmt.Sprintf("user=user_name password=%s dbname=dbname sslmode=disable host=%s port=%s", password, hostAndPortSplited[0], hostAndPortSplited[1])
955+
template := fmt.Sprintf(`
956+
postgres_cdc:
957+
dsn: %s
958+
slot_name: test_slot_native_decoder
959+
stream_snapshot: true
960+
snapshot_batch_size: 5
961+
schema: public
962+
tables:
963+
- '"FlightsCompositePK"'
964+
- flights
965+
`, databaseURL)
966+
967+
streamOutBuilder := service.NewStreamBuilder()
968+
require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: TRACE`))
969+
require.NoError(t, streamOutBuilder.AddInputYAML(template))
970+
// Replace each message body with its metadata so the consumer below can
// assert on the metadata keys directly.
require.NoError(t, streamOutBuilder.AddProcessorYAML(`mapping: 'root = @'`))
971+
972+
// outBatches is appended to by the consumer callback and read by the
// polling assertions below, so access is guarded by outBatchMut.
var outBatches []any
973+
var outBatchMut sync.Mutex
974+
require.NoError(t, streamOutBuilder.AddBatchConsumerFunc(func(c context.Context, batch service.MessageBatch) error {
975+
outBatchMut.Lock()
976+
defer outBatchMut.Unlock()
977+
for _, msg := range batch {
978+
data, err := msg.AsStructured()
979+
// NOTE(review): this callback presumably runs off the test goroutine;
// require's FailNow is only documented as safe from the test goroutine — confirm.
require.NoError(t, err)
980+
d := data.(map[string]any)
981+
if _, ok := d["lsn"]; ok {
982+
d["lsn"] = "XXX/XXX" // Consistent LSN for assertions below
983+
}
984+
outBatches = append(outBatches, data)
985+
}
986+
return nil
987+
}))
988+
989+
streamOut, err := streamOutBuilder.Build()
990+
require.NoError(t, err)
991+
992+
license.InjectTestService(streamOut.Resources())
993+
994+
// Run the stream in the background; it is stopped via StopWithin at the
// end of the test.
go func() {
995+
_ = streamOut.Run(context.Background())
996+
}()
997+
998+
// Wait until the two rows seeded before the stream started have been received.
// NOTE(review): the boolean result is not checked, so a timeout here only
// records a failure and the test continues.
assert.Eventually(t, func() bool {
999+
outBatchMut.Lock()
1000+
defer outBatchMut.Unlock()
1001+
return len(outBatches) == 2
1002+
}, time.Second*25, time.Millisecond*100)
1003+
1004+
// Insert one more row per table while the stream is running; the expected
// values below treat these as "insert" operations that carry an LSN.
_, err = db.Exec(`INSERT INTO "FlightsCompositePK" ("Seq", "Name", "CreatedAt") VALUES ($1, $2, $3);`, 2, "bravo", "2006-01-02T15:04:05Z07:00")
1005+
require.NoError(t, err)
1006+
_, err = db.Exec(`INSERT INTO flights (name, created_at) VALUES ($1, $2);`, "bravo", "2006-01-02T15:04:05Z07:00")
1007+
require.NoError(t, err)
1008+
1009+
assert.EventuallyWithT(t, func(c *assert.CollectT) {
1010+
outBatchMut.Lock()
1011+
defer outBatchMut.Unlock()
1012+
assert.Len(c, outBatches, 4, "got: %#v", outBatches)
1013+
}, time.Second*25, time.Millisecond*100)
1014+
1015+
// Order across tables is not asserted; only the set of metadata maps is.
require.ElementsMatch(
1016+
t,
1017+
outBatches,
1018+
[]any{
1019+
map[string]any{
1020+
"operation": "read",
1021+
"table": "FlightsCompositePK",
1022+
},
1023+
map[string]any{
1024+
"operation": "read",
1025+
"table": "flights",
1026+
},
1027+
map[string]any{
1028+
"operation": "insert",
1029+
"table": "flights",
1030+
"lsn": "XXX/XXX",
1031+
},
1032+
map[string]any{
1033+
"operation": "insert",
1034+
"table": "FlightsCompositePK",
1035+
"lsn": "XXX/XXX",
1036+
},
1037+
},
1038+
)
1039+
1040+
require.NoError(t, streamOut.StopWithin(time.Second*10))
1041+
1042+
}

internal/impl/postgresql/pglogicalstream/logical_stream.go

+11-2
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,15 @@ func (s *Stream) processSnapshot(ctx context.Context, snapshotter *Snapshotter)
503503
wg.Go(func() (err error) {
504504
s.logger.Debugf("Processing snapshot for table: %v", table)
505505

506+
unquotedTable, err := sanitize.UnquotePostgresIdentifier(table.Table)
507+
if err != nil {
508+
return fmt.Errorf("unexpected failure to unquote table name: %w", err)
509+
}
510+
unquotedSchema, err := sanitize.UnquotePostgresIdentifier(table.Schema)
511+
if err != nil {
512+
return fmt.Errorf("unexpected failure to unquote schema name: %w", err)
513+
}
514+
506515
avgRowSizeBytes, numRows, err := snapshotter.tableStats(ctx, table)
507516
if err != nil {
508517
return fmt.Errorf("failed to calculate average row size for table %v: %w", table, err)
@@ -592,8 +601,8 @@ func (s *Stream) processSnapshot(ctx context.Context, snapshotter *Snapshotter)
592601
snapshotChangePacket := StreamMessage{
593602
LSN: nil,
594603
Operation: ReadOpType,
595-
Table: table.Table,
596-
Schema: table.Schema,
604+
Table: unquotedTable,
605+
Schema: unquotedSchema,
597606
Data: data,
598607
}
599608

internal/impl/postgresql/pglogicalstream/sanitize/sanitize.go

+26
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,32 @@ func QuotePostgresIdentifier(name string) string {
381381
return quoted.String()
382382
}
383383

384+
// UnquotePostgresIdentifier strips the surrounding double quotes from a quoted
// identifier and collapses every doubled `""` inside it into a single `"`.
// It returns an error if the input is not wrapped in double quotes or if it
// contains a double quote that is not part of a doubled pair.
385+
func UnquotePostgresIdentifier(quoted string) (string, error) {
386+
var output strings.Builder
387+
// The input must start and end with a double quote and be at least two
388+
// bytes long so the opening and closing quotes are distinct characters.
389+
if !strings.HasPrefix(quoted, `"`) || !strings.HasSuffix(quoted, `"`) || len(quoted) < 2 {
390+
return "", errors.New("missing quotes for identifier")
391+
}
392+
unquoted := quoted[1 : len(quoted)-1]
393+
output.Grow(len(unquoted))
394+
for i := 0; i < len(unquoted); i++ {
395+
_ = output.WriteByte(unquoted[i])
396+
if unquoted[i] != '"' {
397+
continue
398+
}
399+
if i+1 >= len(unquoted) {
400+
return "", fmt.Errorf("invalid quoted identifier: %s", quoted)
401+
}
402+
if unquoted[i+1] != '"' {
403+
return "", fmt.Errorf("invalid quoted identifier: %s", quoted)
404+
}
405+
i++ // Skip the second quote of the doubled pair so only one quote is emitted
406+
}
407+
return output.String(), nil
408+
}
409+
384410
// NormalizePostgresIdentifier checks if a string is a valid PostgreSQL identifier
385411
// This follows PostgreSQL's standard naming rules
386412
func NormalizePostgresIdentifier(name string) (string, error) {

internal/impl/postgresql/pglogicalstream/sanitize/sanitize_test.go

+19-8
Original file line numberDiff line numberDiff line change
@@ -256,23 +256,31 @@ func TestQuerySanitize(t *testing.T) {
256256
}
257257

258258
func TestIdentifierValidation(t *testing.T) {
259-
quoted := []string{
260-
`"FooBar"`,
261-
`"Foo""Bar"`,
262-
`"Foo""""Bar"`,
259+
tests := []struct {
260+
quoted string
261+
unquoted string
262+
}{
263+
{quoted: `"FooBar"`, unquoted: "FooBar"},
264+
{quoted: `"Foo""Bar"`, unquoted: `Foo"Bar`},
265+
{quoted: `"Foo""""Bar"`, unquoted: `Foo""Bar`},
263266
}
264267

265-
for _, i := range quoted {
266-
i := i
267-
t.Run(i, func(t *testing.T) {
268-
_, err := sanitize.NormalizePostgresIdentifier(i)
268+
for _, testcase := range tests {
269+
testcase := testcase
270+
t.Run(testcase.unquoted, func(t *testing.T) {
271+
q, err := sanitize.NormalizePostgresIdentifier(testcase.quoted)
272+
require.NoError(t, err)
273+
require.Equal(t, testcase.quoted, q)
274+
r, err := sanitize.UnquotePostgresIdentifier(q)
269275
require.NoError(t, err)
276+
require.Equal(t, testcase.unquoted, r)
270277
})
271278
}
272279

273280
unquoted := []string{
274281
`_Foobar`,
275282
strings.Repeat("a", 63),
283+
strings.Repeat("A", 63),
276284
}
277285

278286
for _, i := range unquoted {
@@ -281,6 +289,9 @@ func TestIdentifierValidation(t *testing.T) {
281289
normalized, err := sanitize.NormalizePostgresIdentifier(i)
282290
require.NoError(t, err)
283291
require.Equal(t, strconv.Quote(strings.ToLower(i)), normalized)
292+
unquoted, err := sanitize.UnquotePostgresIdentifier(normalized)
293+
require.NoError(t, err)
294+
require.Equal(t, strings.ToLower(i), unquoted)
284295
})
285296
}
286297

0 commit comments

Comments
 (0)