From 7e2c6b19ca55e463f7bb117cc375120fc5bff074 Mon Sep 17 00:00:00 2001 From: Brian Clark Date: Wed, 20 Sep 2023 10:40:39 -0400 Subject: [PATCH 1/3] Issue an explicit rollback if a migration statement returns any error --- database/pgx/pgx.go | 18 ++++++++++--- database/pgx/pgx_test.go | 42 ++++++++++++++++++++++++++++++ database/pgx/v5/pgx.go | 18 ++++++++++--- database/pgx/v5/pgx_test.go | 42 ++++++++++++++++++++++++++++++ database/postgres/TUTORIAL.md | 7 ++++- database/postgres/postgres.go | 17 ++++++++++-- database/postgres/postgres_test.go | 42 ++++++++++++++++++++++++++++++ 7 files changed, 177 insertions(+), 9 deletions(-) diff --git a/database/pgx/pgx.go b/database/pgx/pgx.go index deaca94ea..7266778de 100644 --- a/database/pgx/pgx.go +++ b/database/pgx/pgx.go @@ -285,7 +285,7 @@ func (p *Postgres) runStatement(statement []byte) error { return nil } if _, err := p.conn.ExecContext(ctx, query); err != nil { - + var migrationErr error if pgErr, ok := err.(*pgconn.PgError); ok { var line uint var col uint @@ -298,9 +298,21 @@ func (p *Postgres) runStatement(statement []byte) error { if pgErr.Detail != "" { message = fmt.Sprintf("%s, %s", message, pgErr.Detail) } - return database.Error{OrigErr: err, Err: message, Query: statement, Line: line} + migrationErr = database.Error{OrigErr: err, Err: message, Query: statement, Line: line} + } else { + migrationErr = database.Error{OrigErr: err, Err: "migration failed", Query: statement} + } + + // For safety, always issue a rollback on error. In multi-statement + // mode, this is necessary to make sure that the connection is not left + // in an aborted state. In single-statement mode, this will be a no-op + // outside of the implicit transaction block that was already rolled + // back. + if _, rollbackErr := p.conn.ExecContext(ctx, "ROLLBACK"); rollbackErr != nil { + rollbackErr = fmt.Errorf("failed to rollback migration tx: %w", rollbackErr) + return multierror.Append(migrationErr, rollbackErr) } - return database.Error{OrigErr: err, Err: "migration failed", Query: statement} + return migrationErr } return nil } diff --git a/database/pgx/pgx_test.go b/database/pgx/pgx_test.go index 5d7a5238e..0b0121ad2 100644 --- a/database/pgx/pgx_test.go +++ b/database/pgx/pgx_test.go @@ -167,6 +167,48 @@ func TestMultipleStatements(t *testing.T) { }) } +func TestMultipleStatementsError(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ip, port, err := c.FirstPort() + if err != nil { + t.Fatal(err) + } + + addr := pgConnectionString(ip, port) + p := &Postgres{} + d, err := p.Open(addr) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := d.Close(); err != nil { + t.Error(err) + } + }() + + // Run a migration with explicit transaction that we expect to fail + err = d.Run(strings.NewReader("BEGIN; SELECT 1/0; COMMIT;")) + + // Migration should return expected error + var e database.Error + if !errors.As(err, &e) || err == nil { + t.Fatalf("Unexpected error, want migration error. Got: %#v", err) + } + if !strings.Contains(e.OrigErr.Error(), "division by zero") { + t.Fatalf("Migration error missing expected message. Got: %s", err) + } + + // Connection should still be usable after failed migration + var result int + if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT 123").Scan(&result); err != nil { + t.Fatalf("Unexpected error, want connection to be usable. Got: %s", err) + } + if result != 123 { + t.Fatalf("Unexpected result, want 123. Got: %d", result) + } + }) +} + func TestMultipleStatementsInMultiStatementMode(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { ip, port, err := c.FirstPort() diff --git a/database/pgx/v5/pgx.go b/database/pgx/v5/pgx.go index 1b5a6ea7a..401e98ad4 100644 --- a/database/pgx/v5/pgx.go +++ b/database/pgx/v5/pgx.go @@ -283,7 +283,7 @@ func (p *Postgres) runStatement(statement []byte) error { return nil } if _, err := p.conn.ExecContext(ctx, query); err != nil { - + var migrationErr error if pgErr, ok := err.(*pgconn.PgError); ok { var line uint var col uint @@ -296,9 +296,21 @@ func (p *Postgres) runStatement(statement []byte) error { if pgErr.Detail != "" { message = fmt.Sprintf("%s, %s", message, pgErr.Detail) } - return database.Error{OrigErr: err, Err: message, Query: statement, Line: line} + migrationErr = database.Error{OrigErr: err, Err: message, Query: statement, Line: line} + } else { + migrationErr = database.Error{OrigErr: err, Err: "migration failed", Query: statement} + } + + // For safety, always issue a rollback on error. In multi-statement + // mode, this is necessary to make sure that the connection is not left + // in an aborted state. In single-statement mode, this will be a no-op + // outside of the implicit transaction block that was already rolled + // back. + if _, rollbackErr := p.conn.ExecContext(ctx, "ROLLBACK"); rollbackErr != nil { + rollbackErr = fmt.Errorf("failed to rollback migration tx: %w", rollbackErr) + return multierror.Append(migrationErr, rollbackErr) } - return database.Error{OrigErr: err, Err: "migration failed", Query: statement} + return migrationErr } return nil } diff --git a/database/pgx/v5/pgx_test.go b/database/pgx/v5/pgx_test.go index c7339c4fc..f028804e6 100644 --- a/database/pgx/v5/pgx_test.go +++ b/database/pgx/v5/pgx_test.go @@ -168,6 +168,48 @@ func TestMultipleStatements(t *testing.T) { }) } +func TestMultipleStatementsError(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ip, port, err := c.FirstPort() + if err != nil { + t.Fatal(err) + } + + addr := pgConnectionString(ip, port) + p := &Postgres{} + d, err := p.Open(addr) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := d.Close(); err != nil { + t.Error(err) + } + }() + + // Run a migration with explicit transaction that we expect to fail + err = d.Run(strings.NewReader("BEGIN; SELECT 1/0; COMMIT;")) + + // Migration should return expected error + var e database.Error + if !errors.As(err, &e) || err == nil { + t.Fatalf("Unexpected error, want migration error. Got: %s", err) + } + if !strings.Contains(e.OrigErr.Error(), "division by zero") { + t.Fatalf("Migration error missing expected message. Got: %s", err) + } + + // Connection should still be usable after failed migration + var result int + if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT 123").Scan(&result); err != nil { + t.Fatalf("Unexpected error, want connection to be usable. Got: %s", err) + } + if result != 123 { + t.Fatalf("Unexpected result, want 123. Got: %d", result) + } + }) +} + func TestMultipleStatementsInMultiStatementMode(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { ip, port, err := c.FirstPort() diff --git a/database/postgres/TUTORIAL.md b/database/postgres/TUTORIAL.md index 0f19c56ff..62b5c3d39 100644 --- a/database/postgres/TUTORIAL.md +++ b/database/postgres/TUTORIAL.md @@ -68,6 +68,12 @@ Make sure to check if your database changed as expected in this case as well. ## Database transactions +By default, all the statements in the migration file will be run inside one +implicit transaction. However, you can also use `BEGIN` and `COMMIT` statements +to explicitly control what transactions your migration uses (for example, if you +want to break your migration into multiple transactions to avoid holding onto a +lock for too long). + To show database transactions usage, let's create another set of migrations by running: ``` migrate create -ext sql -dir db/migrations -seq add_mood_to_users @@ -76,7 +82,6 @@ Again, it should create for us two migrations files: - 000002_add_mood_to_users.down.sql - 000002_add_mood_to_users.up.sql -In Postgres, when we want our queries to be done in a transaction, we need to wrap it with `BEGIN` and `COMMIT` commands. In our example, we are going to add a column to our database that can only accept enumerable values or NULL. Migration up: ``` diff --git a/database/postgres/postgres.go b/database/postgres/postgres.go index 9e6d6277f..4a0cc324d 100644 --- a/database/postgres/postgres.go +++ b/database/postgres/postgres.go @@ -296,6 +296,7 @@ func (p *Postgres) runStatement(statement []byte) error { return nil } if _, err := p.conn.ExecContext(ctx, query); err != nil { + var migrationErr error if pgErr, ok := err.(*pq.Error); ok { var line uint var col uint @@ -312,9 +313,21 @@ func (p *Postgres) runStatement(statement []byte) error { if pgErr.Detail != "" { message = fmt.Sprintf("%s, %s", message, pgErr.Detail) } - return database.Error{OrigErr: err, Err: message, Query: statement, Line: line} + migrationErr = database.Error{OrigErr: err, Err: message, Query: statement, Line: line} + } else { + migrationErr = database.Error{OrigErr: err, Err: "migration failed", Query: statement} } - return database.Error{OrigErr: err, Err: "migration failed", Query: statement} + + // For safety, always issue a rollback on error. In multi-statement + // mode, this is necessary to make sure that the connection is not left + // in an aborted state. In single-statement mode, this will be a no-op + // outside of the implicit transaction block that was already rolled + // back. + if _, rollbackErr := p.conn.ExecContext(ctx, "ROLLBACK"); rollbackErr != nil { + rollbackErr = fmt.Errorf("failed to rollback migration tx: %w", rollbackErr) + return multierror.Append(migrationErr, rollbackErr) + } + return migrationErr } return nil } diff --git a/database/postgres/postgres_test.go b/database/postgres/postgres_test.go index 65395cc7e..eeccb74ae 100644 --- a/database/postgres/postgres_test.go +++ b/database/postgres/postgres_test.go @@ -165,6 +165,48 @@ func TestMultipleStatements(t *testing.T) { }) } +func TestMultipleStatementsError(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ip, port, err := c.FirstPort() + if err != nil { + t.Fatal(err) + } + + addr := pgConnectionString(ip, port) + p := &Postgres{} + d, err := p.Open(addr) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := d.Close(); err != nil { + t.Error(err) + } + }() + + // Run a migration with explicit transaction that we expect to fail + err = d.Run(strings.NewReader("BEGIN; SELECT 1/0; COMMIT;")) + + // Migration should return expected error + var e database.Error + if !errors.As(err, &e) || err == nil { + t.Fatalf("Unexpected error, want migration error. Got: %s", err) + } + if !strings.Contains(e.OrigErr.Error(), "division by zero") { + t.Fatalf("Migration error missing expected message. Got: %s", err) + } + + // Connection should still be usable after failed migration + var result int + if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT 123").Scan(&result); err != nil { + t.Fatalf("Unexpected error, want connection to be usable. Got: %s", err) + } + if result != 123 { + t.Fatalf("Unexpected result, want 123. Got: %d", result) + } + }) +} + func TestMultipleStatementsInMultiStatementMode(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { ip, port, err := c.FirstPort() From bfdd2f0b5fe1c6952d155c5f49f5d9c2f24a396b Mon Sep 17 00:00:00 2001 From: Brian Clark Date: Fri, 29 Sep 2023 16:06:06 -0400 Subject: [PATCH 2/3] Change wording of migration tx rollback error message to reduce noise --- database/pgx/pgx.go | 2 +- database/pgx/v5/pgx.go | 2 +- database/postgres/postgres.go | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/database/pgx/pgx.go b/database/pgx/pgx.go index 7266778de..a802f2109 100644 --- a/database/pgx/pgx.go +++ b/database/pgx/pgx.go @@ -309,7 +309,7 @@ func (p *Postgres) runStatement(statement []byte) error { // outside of the implicit transaction block that was already rolled // back. if _, rollbackErr := p.conn.ExecContext(ctx, "ROLLBACK"); rollbackErr != nil { - rollbackErr = fmt.Errorf("failed to rollback migration tx: %w", rollbackErr) + rollbackErr = fmt.Errorf("rolling back migration tx: %w", rollbackErr) return multierror.Append(migrationErr, rollbackErr) } return migrationErr diff --git a/database/pgx/v5/pgx.go b/database/pgx/v5/pgx.go index 401e98ad4..045c3e875 100644 --- a/database/pgx/v5/pgx.go +++ b/database/pgx/v5/pgx.go @@ -307,7 +307,7 @@ func (p *Postgres) runStatement(statement []byte) error { // outside of the implicit transaction block that was already rolled // back. if _, rollbackErr := p.conn.ExecContext(ctx, "ROLLBACK"); rollbackErr != nil { - rollbackErr = fmt.Errorf("failed to rollback migration tx: %w", rollbackErr) + rollbackErr = fmt.Errorf("rolling back migration tx: %w", rollbackErr) return multierror.Append(migrationErr, rollbackErr) } return migrationErr diff --git a/database/postgres/postgres.go b/database/postgres/postgres.go index 4a0cc324d..de1e8d93a 100644 --- a/database/postgres/postgres.go +++ b/database/postgres/postgres.go @@ -324,7 +324,7 @@ func (p *Postgres) runStatement(statement []byte) error { // outside of the implicit transaction block that was already rolled // back. if _, rollbackErr := p.conn.ExecContext(ctx, "ROLLBACK"); rollbackErr != nil { - rollbackErr = fmt.Errorf("failed to rollback migration tx: %w", rollbackErr) + rollbackErr = fmt.Errorf("rolling back migration tx: %w", rollbackErr) return multierror.Append(migrationErr, rollbackErr) } return migrationErr From 4813cdb3d9e407aef2aafc9fd1fcf3d12a6e8f71 Mon Sep 17 00:00:00 2001 From: Brian Clark Date: Fri, 29 Sep 2023 16:18:05 -0400 Subject: [PATCH 3/3] Define migrationErr with non-nil value from the get-go --- database/pgx/pgx.go | 4 +--- database/pgx/v5/pgx.go | 4 +--- database/postgres/postgres.go | 4 +--- 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/database/pgx/pgx.go b/database/pgx/pgx.go index a802f2109..5f95ca1c8 100644 --- a/database/pgx/pgx.go +++ b/database/pgx/pgx.go @@ -285,7 +285,7 @@ func (p *Postgres) runStatement(statement []byte) error { return nil } if _, err := p.conn.ExecContext(ctx, query); err != nil { - var migrationErr error + migrationErr := database.Error{OrigErr: err, Err: "migration failed", Query: statement} if pgErr, ok := err.(*pgconn.PgError); ok { var line uint var col uint @@ -299,8 +299,6 @@ func (p *Postgres) runStatement(statement []byte) error { message = fmt.Sprintf("%s, %s", message, pgErr.Detail) } migrationErr = database.Error{OrigErr: err, Err: message, Query: statement, Line: line} - } else { - migrationErr = database.Error{OrigErr: err, Err: "migration failed", Query: statement} } // For safety, always issue a rollback on error. In multi-statement diff --git a/database/pgx/v5/pgx.go b/database/pgx/v5/pgx.go index 045c3e875..daf6d784c 100644 --- a/database/pgx/v5/pgx.go +++ b/database/pgx/v5/pgx.go @@ -283,7 +283,7 @@ func (p *Postgres) runStatement(statement []byte) error { return nil } if _, err := p.conn.ExecContext(ctx, query); err != nil { - var migrationErr error + migrationErr := database.Error{OrigErr: err, Err: "migration failed", Query: statement} if pgErr, ok := err.(*pgconn.PgError); ok { var line uint var col uint @@ -297,8 +297,6 @@ func (p *Postgres) runStatement(statement []byte) error { message = fmt.Sprintf("%s, %s", message, pgErr.Detail) } migrationErr = database.Error{OrigErr: err, Err: message, Query: statement, Line: line} - } else { - migrationErr = database.Error{OrigErr: err, Err: "migration failed", Query: statement} } // For safety, always issue a rollback on error. In multi-statement diff --git a/database/postgres/postgres.go b/database/postgres/postgres.go index de1e8d93a..8ad8f2cc8 100644 --- a/database/postgres/postgres.go +++ b/database/postgres/postgres.go @@ -296,7 +296,7 @@ func (p *Postgres) runStatement(statement []byte) error { return nil } if _, err := p.conn.ExecContext(ctx, query); err != nil { - var migrationErr error + migrationErr := database.Error{OrigErr: err, Err: "migration failed", Query: statement} if pgErr, ok := err.(*pq.Error); ok { var line uint var col uint @@ -314,8 +314,6 @@ func (p *Postgres) runStatement(statement []byte) error { message = fmt.Sprintf("%s, %s", message, pgErr.Detail) } migrationErr = database.Error{OrigErr: err, Err: message, Query: statement, Line: line} - } else { - migrationErr = database.Error{OrigErr: err, Err: "migration failed", Query: statement} } // For safety, always issue a rollback on error. In multi-statement