Skip to content

Commit

Permalink
feat(experimental): keep Store internal and add Tablename method (#624)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfridman authored Oct 28, 2023
1 parent 4ec43df commit d59dd9f
Show file tree
Hide file tree
Showing 9 changed files with 130 additions and 85 deletions.
32 changes: 16 additions & 16 deletions database/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,16 @@ const (
DialectTiDB Dialect = "tidb"
DialectVertica Dialect = "vertica"
DialectYdB Dialect = "ydb"

// DialectCustom is a special dialect that allows users to provide their own [Store]
// implementation when constructing a [goose.Provider].
DialectCustom Dialect = "custom"
)

// NewStore returns a new [Store] backed by the given dialect.
// NewStore returns a new [Store] implementation for the given dialect.
func NewStore(dialect Dialect, tablename string) (Store, error) {
if tablename == "" {
return nil, errors.New("tablename must not be empty")
return nil, errors.New("table name must not be empty")
}
if dialect == "" {
return nil, errors.New("dialect must not be empty")
}
if dialect == DialectCustom {
return nil, errors.New("dialect must not be custom")
}
lookup := map[Dialect]dialectquery.Querier{
DialectClickHouse: &dialectquery.Clickhouse{},
DialectMSSQL: &dialectquery.Sqlserver{},
Expand Down Expand Up @@ -66,6 +59,12 @@ type store struct {

var _ Store = (*store)(nil)

func (s *store) private() {}

func (s *store) Tablename() string {
return s.tablename
}

func (s *store) CreateVersionTable(ctx context.Context, db DBTxConn) error {
q := s.querier.CreateTable(s.tablename)
if _, err := db.ExecContext(ctx, q); err != nil {
Expand All @@ -74,14 +73,15 @@ func (s *store) CreateVersionTable(ctx context.Context, db DBTxConn) error {
return nil
}

func (s *store) InsertOrDelete(ctx context.Context, db DBTxConn, direction bool, version int64) error {
if direction {
q := s.querier.InsertVersion(s.tablename)
if _, err := db.ExecContext(ctx, q, version, true); err != nil {
return fmt.Errorf("failed to insert version %d: %w", version, err)
}
return nil
func (s *store) Insert(ctx context.Context, db DBTxConn, req InsertRequest) error {
q := s.querier.InsertVersion(s.tablename)
if _, err := db.ExecContext(ctx, q, req.Version, true); err != nil {
return fmt.Errorf("failed to insert version %d: %w", req.Version, err)
}
return nil
}

func (s *store) Delete(ctx context.Context, db DBTxConn, version int64) error {
q := s.querier.DeleteVersion(s.tablename)
if _, err := db.ExecContext(ctx, q, version); err != nil {
return fmt.Errorf("failed to delete version %d: %w", version, err)
Expand Down
23 changes: 20 additions & 3 deletions database/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,18 @@ import (
// Each database dialect requires a specific implementation of this interface. A dialect represents
// a set of SQL statements specific to a particular database system.
type Store interface {
// Tablename is the version table used to record applied migrations. Must not be empty.
Tablename() string

// CreateVersionTable creates the version table. This table is used to record applied
// migrations.
CreateVersionTable(ctx context.Context, db DBTxConn) error

// InsertOrDelete inserts or deletes a version id from the version table. If direction is true,
// insert the version id. If direction is false, delete the version id.
InsertOrDelete(ctx context.Context, db DBTxConn, direction bool, version int64) error
// Insert inserts a version id into the version table.
Insert(ctx context.Context, db DBTxConn, req InsertRequest) error

// Delete deletes a version id from the version table.
Delete(ctx context.Context, db DBTxConn, version int64) error

// GetMigration retrieves a single migration by version id. This method may return the raw sql
// error if the query fails so the caller can assert for errors such as [sql.ErrNoRows].
Expand All @@ -26,6 +31,18 @@ type Store interface {
// ListMigrations retrieves all migrations sorted in descending order by id or timestamp. If
// there are no migrations, return empty slice with no error.
ListMigrations(ctx context.Context, db DBTxConn) ([]*ListMigrationsResult, error)

// TODO(mf): remove this method once the Provider is public and a custom Store can be used.
private()
}

type InsertRequest struct {
Version int64

// TODO(mf): in the future, we maybe want to expand this struct so implementors can store
// additional information. See the following issues for more information:
// - https://github.com/pressly/goose/issues/422
// - https://github.com/pressly/goose/issues/288
}

type GetMigrationResult struct {
Expand Down
21 changes: 11 additions & 10 deletions database/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ func TestDialectStore(t *testing.T) {
// Test empty dialect.
_, err = database.NewStore("", "foo")
check.HasError(t, err)
_, err = database.NewStore(database.DialectCustom, "foo")
check.HasError(t, err)
})
t.Run("postgres", func(t *testing.T) {
if testing.Short() {
Expand Down Expand Up @@ -69,9 +67,12 @@ func TestDialectStore(t *testing.T) {
check.NoError(t, err)
err = store.CreateVersionTable(context.Background(), db)
check.NoError(t, err)
check.NoError(t, store.InsertOrDelete(context.Background(), db, true, 1))
check.NoError(t, store.InsertOrDelete(context.Background(), db, true, 3))
check.NoError(t, store.InsertOrDelete(context.Background(), db, true, 2))
insert := func(db *sql.DB, version int64) error {
return store.Insert(context.Background(), db, database.InsertRequest{Version: version})
}
check.NoError(t, insert(db, 1))
check.NoError(t, insert(db, 3))
check.NoError(t, insert(db, 2))
res, err := store.ListMigrations(context.Background(), db)
check.NoError(t, err)
check.Number(t, len(res), 3)
Expand Down Expand Up @@ -124,7 +125,7 @@ func testStore(
// Insert 5 migrations in addition to the zero migration.
for i := 0; i < 6; i++ {
err = runConn(ctx, db, func(conn *sql.Conn) error {
return store.InsertOrDelete(ctx, conn, true, int64(i))
return store.Insert(ctx, conn, database.InsertRequest{Version: int64(i)})
})
check.NoError(t, err)
}
Expand All @@ -145,7 +146,7 @@ func testStore(
// Delete 3 migrations backwards
for i := 5; i >= 3; i-- {
err = runConn(ctx, db, func(conn *sql.Conn) error {
return store.InsertOrDelete(ctx, conn, false, int64(i))
return store.Delete(ctx, conn, int64(i))
})
check.NoError(t, err)
}
Expand Down Expand Up @@ -179,16 +180,16 @@ func testStore(

// 1. *sql.Tx
err = runTx(ctx, db, func(tx *sql.Tx) error {
return store.InsertOrDelete(ctx, tx, false, 2)
return store.Delete(ctx, tx, 2)
})
check.NoError(t, err)
// 2. *sql.Conn
err = runConn(ctx, db, func(conn *sql.Conn) error {
return store.InsertOrDelete(ctx, conn, false, 1)
return store.Delete(ctx, conn, 1)
})
check.NoError(t, err)
// 3. *sql.DB
err = store.InsertOrDelete(ctx, db, false, 0)
err = store.Delete(ctx, db, 0)
check.NoError(t, err)

// List migrations. There should be none.
Expand Down
27 changes: 20 additions & 7 deletions internal/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ import (
//
// The caller is responsible for matching the database dialect with the database/sql driver. For
// example, if the database dialect is "postgres", the database/sql driver could be
// github.com/lib/pq or github.com/jackc/pgx.
// github.com/lib/pq or github.com/jackc/pgx. Each dialect has a corresponding [database.Dialect]
// constant backed by a default [database.Store] implementation. For more advanced use cases, such
// as using a custom table name or supplying a custom store implementation, see [WithStore].
//
// fsys is the filesystem used to read the migration files, but may be nil. Most users will want to
// use [os.DirFS], os.DirFS("path/to/migrations"), to read migrations from the local filesystem.
Expand Down Expand Up @@ -44,13 +46,24 @@ func NewProvider(dialect database.Dialect, db *sql.DB, fsys fs.FS, opts ...Provi
return nil, err
}
}
// Set defaults after applying user-supplied options so option funcs can check for empty values.
if cfg.tableName == "" {
cfg.tableName = DefaultTablename
if dialect == "" && cfg.store == nil {
return nil, errors.New("dialect must not be empty")
}
store, err := database.NewStore(dialect, cfg.tableName)
if err != nil {
return nil, err
if dialect != "" && cfg.store != nil {
return nil, errors.New("cannot set both dialect and store")
}
var store database.Store
if dialect != "" {
var err error
store, err = database.NewStore(dialect, DefaultTablename)
if err != nil {
return nil, err
}
} else {
store = cfg.store
}
if store.Tablename() == "" {
return nil, errors.New("invalid store implementation: table name must not be empty")
}
// Collect migrations from the filesystem and merge with registered migrations.
//
Expand Down
42 changes: 31 additions & 11 deletions internal/provider/provider_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"fmt"

"github.com/pressly/goose/v3/database"
"github.com/pressly/goose/v3/lock"
)

Expand All @@ -21,18 +22,36 @@ type ProviderOption interface {
apply(*config) error
}

// WithTableName sets the name of the database table used to track history of applied migrations.
// WithStore configures the provider with a custom [database.Store] implementation.
//
// If WithTableName is not called, the default value is "goose_db_version".
func WithTableName(name string) ProviderOption {
// By default, the provider uses the [database.NewStore] function to create a store backed by the
// given dialect. However, this option allows users to provide their own implementation or call
// [database.NewStore] with custom options, such as setting the table name.
//
// Example:
//
// // Create a store with a custom table name.
// store, err := database.NewStore(database.DialectPostgres, "my_custom_table_name")
// if err != nil {
// return err
// }
// // Create a provider with the custom store.
// provider, err := goose.NewProvider("", db, nil, goose.WithStore(store))
// if err != nil {
// return err
// }
func WithStore(store database.Store) ProviderOption {
return configFunc(func(c *config) error {
if c.tableName != "" {
return fmt.Errorf("table already set to %q", c.tableName)
if c.store != nil {
return fmt.Errorf("store already set: %T", c.store)
}
if name == "" {
return errors.New("table must not be empty")
if store == nil {
return errors.New("store must not be nil")
}
c.tableName = name
if store.Tablename() == "" {
return errors.New("store implementation must set the table name")
}
c.store = store
return nil
})
}
Expand Down Expand Up @@ -148,9 +167,10 @@ func WithNoVersioning(b bool) ProviderOption {
}

type config struct {
tableName string
verbose bool
excludes map[string]bool
store database.Store

verbose bool
excludes map[string]bool

// Go migrations registered by the user. These will be merged/resolved with migrations from the
// filesystem and init() functions.
Expand Down
40 changes: 21 additions & 19 deletions internal/provider/provider_options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"testing"
"testing/fstest"

"github.com/pressly/goose/v3/database"
"github.com/pressly/goose/v3/internal/check"
"github.com/pressly/goose/v3/internal/provider"
_ "modernc.org/sqlite"
Expand All @@ -29,38 +30,39 @@ func TestNewProvider(t *testing.T) {
_, err = provider.NewProvider("unknown-dialect", db, fsys)
check.HasError(t, err)
// Nil db not allowed
_, err = provider.NewProvider("sqlite3", nil, fsys)
_, err = provider.NewProvider(database.DialectSQLite3, nil, fsys)
check.HasError(t, err)
// Nil fsys not allowed
_, err = provider.NewProvider("sqlite3", db, nil)
_, err = provider.NewProvider(database.DialectSQLite3, db, nil)
check.HasError(t, err)
// Duplicate table name not allowed
_, err = provider.NewProvider("sqlite3", db, fsys,
provider.WithTableName("foo"),
provider.WithTableName("bar"),
)
// Nil store not allowed
_, err = provider.NewProvider(database.DialectSQLite3, db, nil, provider.WithStore(nil))
check.HasError(t, err)
// Cannot set both dialect and store
store, err := database.NewStore(database.DialectSQLite3, "custom_table")
check.NoError(t, err)
_, err = provider.NewProvider(database.DialectSQLite3, db, nil, provider.WithStore(store))
check.HasError(t, err)
check.Equal(t, `table already set to "foo"`, err.Error())
// Empty table name not allowed
_, err = provider.NewProvider("sqlite3", db, fsys,
provider.WithTableName(""),
// Multiple stores not allowed
_, err = provider.NewProvider(database.DialectSQLite3, db, nil,
provider.WithStore(store),
provider.WithStore(store),
)
check.HasError(t, err)
check.Equal(t, "table must not be empty", err.Error())
})
t.Run("valid", func(t *testing.T) {
// Valid dialect, db, and fsys allowed
_, err = provider.NewProvider("sqlite3", db, fsys)
check.NoError(t, err)
// Valid dialect, db, fsys, and table name allowed
_, err = provider.NewProvider("sqlite3", db, fsys,
provider.WithTableName("foo"),
)
_, err = provider.NewProvider(database.DialectSQLite3, db, fsys)
check.NoError(t, err)
// Valid dialect, db, fsys, and verbose allowed
_, err = provider.NewProvider("sqlite3", db, fsys,
_, err = provider.NewProvider(database.DialectSQLite3, db, fsys,
provider.WithVerbose(testing.Verbose()),
)
check.NoError(t, err)
// Custom store allowed
store, err := database.NewStore(database.DialectSQLite3, "custom_table")
check.NoError(t, err)
_, err = provider.NewProvider("", db, nil, provider.WithStore(store))
check.HasError(t, err)
})
}
4 changes: 2 additions & 2 deletions internal/provider/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func TestProvider(t *testing.T) {
db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db"))
check.NoError(t, err)
t.Run("empty", func(t *testing.T) {
_, err := provider.NewProvider("sqlite3", db, fstest.MapFS{})
_, err := provider.NewProvider(database.DialectSQLite3, db, fstest.MapFS{})
check.HasError(t, err)
check.Bool(t, errors.Is(err, provider.ErrNoMigrations), true)
})
Expand All @@ -31,7 +31,7 @@ func TestProvider(t *testing.T) {
}
fsys, err := fs.Sub(mapFS, "migrations")
check.NoError(t, err)
p, err := provider.NewProvider("sqlite3", db, fsys)
p, err := provider.NewProvider(database.DialectSQLite3, db, fsys)
check.NoError(t, err)
sources := p.ListSources()
check.Equal(t, len(sources), 2)
Expand Down
12 changes: 9 additions & 3 deletions internal/provider/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,10 @@ func (p *Provider) runIndividually(
if p.cfg.noVersioning {
return nil
}
return p.store.InsertOrDelete(ctx, tx, direction, m.Source.Version)
if direction {
return p.store.Insert(ctx, tx, database.InsertRequest{Version: m.Source.Version})
}
return p.store.Delete(ctx, tx, m.Source.Version)
})
}
// Run the migration outside of a transaction.
Expand All @@ -268,7 +271,10 @@ func (p *Provider) runIndividually(
if p.cfg.noVersioning {
return nil
}
return p.store.InsertOrDelete(ctx, conn, direction, m.Source.Version)
if direction {
return p.store.Insert(ctx, conn, database.InsertRequest{Version: m.Source.Version})
}
return p.store.Delete(ctx, conn, m.Source.Version)
}

// beginTx begins a transaction and runs the given function. If the function returns an error, the
Expand Down Expand Up @@ -367,7 +373,7 @@ func (p *Provider) ensureVersionTable(ctx context.Context, conn *sql.Conn) (retE
if p.cfg.noVersioning {
return nil
}
return p.store.InsertOrDelete(ctx, tx, true, 0)
return p.store.Insert(ctx, tx, database.InsertRequest{Version: 0})
})
}

Expand Down
Loading

0 comments on commit d59dd9f

Please sign in to comment.