Skip to content

Commit 8d90b84

Browse files
committed
Issue an explicit rollback if a migration statement returns any error
1 parent 856ea12 commit 8d90b84

File tree

5 files changed

+93
-9
lines changed

5 files changed

+93
-9
lines changed

database/pgx/pgx.go

+15-3
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ func (p *Postgres) runStatement(statement []byte) error {
285285
return nil
286286
}
287287
if _, err := p.conn.ExecContext(ctx, query); err != nil {
288-
288+
var migrationErr error
289289
if pgErr, ok := err.(*pgconn.PgError); ok {
290290
var line uint
291291
var col uint
@@ -298,9 +298,21 @@ func (p *Postgres) runStatement(statement []byte) error {
298298
if pgErr.Detail != "" {
299299
message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
300300
}
301-
return database.Error{OrigErr: err, Err: message, Query: statement, Line: line}
301+
migrationErr = database.Error{OrigErr: err, Err: message, Query: statement, Line: line}
302+
} else {
303+
migrationErr = database.Error{OrigErr: err, Err: "migration failed", Query: statement}
304+
}
305+
306+
// For safety, always issue a rollback on error. In multi-statement
307+
// mode, this is necessary to make sure that the connection is not left
308+
// in an aborted state. In single-statement mode, this will be a no-op
309+
// outside of the implicit transaction block that was already rolled
310+
// back.
311+
if _, rollbackErr := p.conn.ExecContext(ctx, "ROLLBACK"); rollbackErr != nil {
312+
rollbackErr = fmt.Errorf("failed to rollback migration tx: %w", rollbackErr)
313+
return multierror.Append(migrationErr, rollbackErr)
302314
}
303-
return database.Error{OrigErr: err, Err: "migration failed", Query: statement}
315+
return migrationErr
304316
}
305317
return nil
306318
}

database/pgx/pgx_test.go

+42
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,48 @@ func TestMultipleStatements(t *testing.T) {
167167
})
168168
}
169169

170+
func TestMultipleStatementsError(t *testing.T) {
171+
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
172+
ip, port, err := c.FirstPort()
173+
if err != nil {
174+
t.Fatal(err)
175+
}
176+
177+
addr := pgConnectionString(ip, port)
178+
p := &Postgres{}
179+
d, err := p.Open(addr)
180+
if err != nil {
181+
t.Fatal(err)
182+
}
183+
defer func() {
184+
if err := d.Close(); err != nil {
185+
t.Error(err)
186+
}
187+
}()
188+
189+
// Run a migration with explicit transaction that we expect to fail
190+
migrationErr := d.Run(strings.NewReader("BEGIN; SELECT 1/0; COMMIT;"))
191+
192+
// Migration should return expected error
193+
var e *database.Error
194+
if !errors.As(migrationErr, &e) || err == nil {
195+
t.Fatalf("Unexpected error, want migration error. Got: %s", err)
196+
}
197+
if !strings.Contains(e.OrigErr.Error(), "division by zero") {
198+
t.Fatalf("Migration error missing expected message. Got: %s", err)
199+
}
200+
201+
// Connection should still be usable after failed migration
202+
var result int
203+
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT 123").Scan(&result); err != nil {
204+
t.Fatalf("Unexpected error, want connection to be usable. Got: %s", err)
205+
}
206+
if result != 123 {
207+
t.Fatalf("Unexpected result, want 123. Got: %d", result)
208+
}
209+
})
210+
}
211+
170212
func TestMultipleStatementsInMultiStatementMode(t *testing.T) {
171213
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
172214
ip, port, err := c.FirstPort()

database/pgx/v5/pgx.go

+15-3
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ func (p *Postgres) runStatement(statement []byte) error {
283283
return nil
284284
}
285285
if _, err := p.conn.ExecContext(ctx, query); err != nil {
286-
286+
var migrationErr error
287287
if pgErr, ok := err.(*pgconn.PgError); ok {
288288
var line uint
289289
var col uint
@@ -296,9 +296,21 @@ func (p *Postgres) runStatement(statement []byte) error {
296296
if pgErr.Detail != "" {
297297
message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
298298
}
299-
return database.Error{OrigErr: err, Err: message, Query: statement, Line: line}
299+
migrationErr = database.Error{OrigErr: err, Err: message, Query: statement, Line: line}
300+
} else {
301+
migrationErr = database.Error{OrigErr: err, Err: "migration failed", Query: statement}
302+
}
303+
304+
// For safety, always issue a rollback on error. In multi-statement
305+
// mode, this is necessary to make sure that the connection is not left
306+
// in an aborted state. In single-statement mode, this will be a no-op
307+
// outside of the implicit transaction block that was already rolled
308+
// back.
309+
if _, rollbackErr := p.conn.ExecContext(ctx, "ROLLBACK"); rollbackErr != nil {
310+
rollbackErr = fmt.Errorf("failed to rollback migration tx: %w", rollbackErr)
311+
return multierror.Append(migrationErr, rollbackErr)
300312
}
301-
return database.Error{OrigErr: err, Err: "migration failed", Query: statement}
313+
return migrationErr
302314
}
303315
return nil
304316
}

database/postgres/TUTORIAL.md

+6-1
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,12 @@ Make sure to check if your database changed as expected in this case as well.
6868

6969
## Database transactions
7070

71+
By default, all the statements in the migration file will be run inside one
72+
implicit transaction. However, you can also use `BEGIN` and `COMMIT` statements
73+
to explicitly control what transactions your migration uses (for example, if you
74+
want to break your migration into multiple transactions to avoid holding onto a
75+
lock for too long).
76+
7177
To show database transactions usage, let's create another set of migrations by running:
7278
```
7379
migrate create -ext sql -dir db/migrations -seq add_mood_to_users
@@ -76,7 +82,6 @@ Again, it should create for us two migrations files:
7682
- 000002_add_mood_to_users.down.sql
7783
- 000002_add_mood_to_users.up.sql
7884

79-
In Postgres, when we want our queries to be done in a transaction, we need to wrap it with `BEGIN` and `COMMIT` commands.
8085
In our example, we are going to add a column to our database that can only accept enumerable values or NULL.
8186
Migration up:
8287
```

database/postgres/postgres.go

+15-2
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ func (p *Postgres) runStatement(statement []byte) error {
296296
return nil
297297
}
298298
if _, err := p.conn.ExecContext(ctx, query); err != nil {
299+
var migrationErr error
299300
if pgErr, ok := err.(*pq.Error); ok {
300301
var line uint
301302
var col uint
@@ -312,9 +313,21 @@ func (p *Postgres) runStatement(statement []byte) error {
312313
if pgErr.Detail != "" {
313314
message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
314315
}
315-
return database.Error{OrigErr: err, Err: message, Query: statement, Line: line}
316+
migrationErr = database.Error{OrigErr: err, Err: message, Query: statement, Line: line}
317+
} else {
318+
migrationErr = database.Error{OrigErr: err, Err: "migration failed", Query: statement}
316319
}
317-
return database.Error{OrigErr: err, Err: "migration failed", Query: statement}
320+
321+
// For safety, always issue a rollback on error. In multi-statement
322+
// mode, this is necessary to make sure that the connection is not left
323+
// in an aborted state. In single-statement mode, this will be a no-op
324+
// outside of the implicit transaction block that was already rolled
325+
// back.
326+
if _, rollbackErr := p.conn.ExecContext(ctx, "ROLLBACK"); rollbackErr != nil {
327+
rollbackErr = fmt.Errorf("failed to rollback migration tx: %w", rollbackErr)
328+
return multierror.Append(migrationErr, rollbackErr)
329+
}
330+
return migrationErr
318331
}
319332
return nil
320333
}

0 commit comments

Comments
 (0)