diff --git a/database/contracts.go b/database/contracts.go index 6da723a8..0dcf3a27 100644 --- a/database/contracts.go +++ b/database/contracts.go @@ -34,6 +34,11 @@ type IDer interface { // EntityFactoryFunc knows how to create an Entity. type EntityFactoryFunc func() Entity +type EntityConstraint[T any] interface { + Entity + *T +} + // Upserter implements the Upsert method, // which returns a part of the object for ON DUPLICATE KEY UPDATE. type Upserter interface { diff --git a/database/db.go b/database/db.go index 39419dc1..325d88bd 100644 --- a/database/db.go +++ b/database/db.go @@ -22,6 +22,7 @@ import ( "go.uber.org/zap/zapcore" "golang.org/x/sync/errgroup" "golang.org/x/sync/semaphore" + _ "modernc.org/sqlite" "net" "net/url" "slices" @@ -39,6 +40,7 @@ type DB struct { Options *Options addr string + queryBuilder QueryBuilder columnMap ColumnMap logger *logging.Logger tableSemaphores map[string]*semaphore.Weighted @@ -256,6 +258,7 @@ func NewDbFromConfig(c *Config, logger *logging.Logger, connectorCallbacks Retry return &DB{ DB: db, Options: &c.Options, + queryBuilder: NewQueryBuilder(db.DriverName()), columnMap: NewColumnMap(db.Mapper), addr: addr, logger: logger, @@ -932,6 +935,10 @@ func (db *DB) Log(ctx context.Context, query string, counter *com.Counter) perio })) } +func (db *DB) QueryBuilder() QueryBuilder { + return db.queryBuilder +} + var ( // Assert TxOrDB interface compliance of the DB and sqlx.Tx types. _ TxOrDB = (*DB)(nil) diff --git a/database/delete.go b/database/delete.go new file mode 100644 index 00000000..084c7bef --- /dev/null +++ b/database/delete.go @@ -0,0 +1,213 @@ +package database + +import ( + "context" + "fmt" + "github.com/icinga/icinga-go-library/backoff" + "github.com/icinga/icinga-go-library/com" + "github.com/icinga/icinga-go-library/retry" + "github.com/jmoiron/sqlx" + "github.com/pkg/errors" + "golang.org/x/sync/errgroup" + "golang.org/x/sync/semaphore" + "reflect" + "time" +) + +// DeleteStatement is the interface for building DELETE statements. +type DeleteStatement interface { + // From sets the table name for the DELETE statement. + // Overrides the table name provided by the entity. + From(table string) DeleteStatement + + // SetWhere sets the where clause for the DELETE statement. + SetWhere(where string) DeleteStatement + + // Entity returns the entity associated with the DELETE statement. + Entity() Entity + + // Table returns the table name for the DELETE statement. + Table() string + + Where() string +} + +// NewDeleteStatement returns a new deleteStatement for the given entity. +func NewDeleteStatement(entity Entity) DeleteStatement { + return &deleteStatement{ + entity: entity, + } +} + +// deleteStatement is the default implementation of the DeleteStatement interface. +type deleteStatement struct { + entity Entity + table string + where string +} + +func (d *deleteStatement) From(table string) DeleteStatement { + d.table = table + + return d +} + +func (d *deleteStatement) SetWhere(where string) DeleteStatement { + d.where = where + + return d +} + +func (d *deleteStatement) Entity() Entity { + return d.entity +} + +func (d *deleteStatement) Table() string { + return d.table +} + +func (d *deleteStatement) Where() string { + return d.where +} + +// DeleteOption is a functional option for DeleteStreamed(). +type DeleteOption func(opts *deleteOptions) + +// WithDeleteStatement sets the DELETE statement to be used for deleting entities. +func WithDeleteStatement(stmt DeleteStatement) DeleteOption { + return func(opts *deleteOptions) { + opts.stmt = stmt + } +} + +// WithOnDelete sets the callbacks for a successful DELETE operation. +func WithOnDelete(onDelete ...OnSuccess[any]) DeleteOption { + return func(opts *deleteOptions) { + opts.onDelete = append(opts.onDelete, onDelete...) + } +} + +// deleteOptions stores the options for DeleteStreamed. +type deleteOptions struct { + stmt DeleteStatement + onDelete []OnSuccess[any] +} + +// DeleteStreamed deletes entities from the given channel from the database. +func DeleteStreamed( + ctx context.Context, + db *DB, + entityType Entity, + entities <-chan any, + options ...DeleteOption, +) error { + opts := &deleteOptions{} + for _, option := range options { + option(opts) + } + + first, forward, err := com.CopyFirst(ctx, entities) + if err != nil { + return errors.Wrap(err, "can't copy first entity") + } + + sem := db.GetSemaphoreForTable(TableName(entityType)) + + var stmt string + + if opts.stmt != nil { + stmt, err = db.QueryBuilder().DeleteStatement(opts.stmt) + if err != nil { + return err + } + } else { + stmt, err = db.QueryBuilder().DeleteStatement(NewDeleteStatement(entityType)) + if err != nil { + return err + } + } + + switch reflect.TypeOf(first).Kind() { + case reflect.Struct, reflect.Map: + return namedBulkExec(ctx, db, stmt, db.Options.MaxPlaceholdersPerStatement, sem, forward, com.NeverSplit[any], opts.onDelete...) + default: + return bulkExec(ctx, db, stmt, db.Options.MaxPlaceholdersPerStatement, sem, forward, opts.onDelete...) + } +} + +func bulkExec( + ctx context.Context, db *DB, query string, count int, sem *semaphore.Weighted, arg <-chan any, onSuccess ...OnSuccess[any], +) error { + var counter com.Counter + defer db.Log(ctx, query, &counter).Stop() + + g, ctx := errgroup.WithContext(ctx) + // Use context from group. + bulk := com.Bulk(ctx, arg, count, com.NeverSplit[any]) + + g.Go(func() error { + g, ctx := errgroup.WithContext(ctx) + + for b := range bulk { + if err := sem.Acquire(ctx, 1); err != nil { + return errors.Wrap(err, "can't acquire semaphore") + } + + g.Go(func(b []any) func() error { + return func() error { + defer sem.Release(1) + + return retry.WithBackoff( + ctx, + func(context.Context) error { + var valCollection []any + + for _, v := range b { + val := reflect.ValueOf(v) + if val.Kind() == reflect.Slice { + for i := 0; i < val.Len(); i++ { + valCollection = append(valCollection, val.Index(i).Interface()) + } + } else { + valCollection = append(valCollection, val.Interface()) + } + } + + stmt, args, err := sqlx.In(query, valCollection) + if err != nil { + return fmt.Errorf( + "%w: %w", + retry.ErrNotRetryable, + errors.Wrapf(err, "can't build placeholders for %q", query), + ) + } + + stmt = db.Rebind(stmt) + _, err = db.ExecContext(ctx, stmt, args...) + if err != nil { + return CantPerformQuery(err, query) + } + + counter.Add(uint64(len(b))) + + for _, onSuccess := range onSuccess { + if err := onSuccess(ctx, b); err != nil { + return err + } + } + + return nil + }, + retry.Retryable, + backoff.NewExponentialWithJitter(1*time.Millisecond, 1*time.Second), + db.GetDefaultRetrySettings(), + ) + } + }(b)) + } + + return g.Wait() + }) + + return g.Wait() +} diff --git a/database/example_upsert_test.go b/database/example_upsert_test.go new file mode 100644 index 00000000..77b08821 --- /dev/null +++ b/database/example_upsert_test.go @@ -0,0 +1,324 @@ +package database + +import ( + "context" + "fmt" + "github.com/icinga/icinga-go-library/com" + "golang.org/x/sync/errgroup" + "time" +) + +func ExampleUpsertStreamed() { + var ( + testEntites = []User{ + {Id: 1, Name: "test1", Age: 10, Email: "test1@test.com"}, + {Id: 2, Name: "test2", Age: 20, Email: "test2@test.com"}, + } + testSelect = &[]User{} + g, ctx = errgroup.WithContext(context.Background()) + entities = make(chan User, len(testEntites)) + logs = getTestLogging() + db = getTestDb(logs) + log = logs.GetLogger() + err error + ) + initTestDb(db) + + g.Go(func() error { + return UpsertStreamed(ctx, db, entities) + }) + + for _, entity := range testEntites { + entities <- entity + } + + close(entities) + time.Sleep(10 * time.Millisecond) + + if err = db.Select(testSelect, "SELECT * FROM user"); err != nil { + log.Fatalf("cannot select from db: %v", err) + } + + fmt.Println(*testSelect) + + if err = g.Wait(); err != nil { + log.Fatalf("error while upserting entities: %v", err) + } + + _ = db.Close() + + // Output: + // [{1 test1 10 test1@test.com} {2 test2 20 test2@test.com}] +} + +func ExampleUpsertStreamedWithStatement() { + var ( + testEntites = []User{ + {Id: 1, Name: "test1"}, + {Id: 2, Name: "test2"}, + } + testSelect = &[]User{} + g, ctx = errgroup.WithContext(context.Background()) + entities = make(chan User, len(testEntites)) + logs = getTestLogging() + db = getTestDb(logs) + log = logs.GetLogger() + stmt = NewUpsertStatement(User{}).SetColumns("id", "name") + err error + ) + initTestDb(db) + + g.Go(func() error { + return UpsertStreamed(ctx, db, entities, WithUpsertStatement(stmt)) + }) + + for _, entity := range testEntites { + entities <- entity + } + + close(entities) + time.Sleep(10 * time.Millisecond) + + if err = db.Select(testSelect, "SELECT * FROM user"); err != nil { + log.Fatalf("cannot select from db: %v", err) + } + + fmt.Println(*testSelect) + + if err = g.Wait(); err != nil { + log.Fatalf("error while upserting entities: %v", err) + } + + _ = db.Close() + + // Output: + // [{1 test1 0 } {2 test2 0 }] +} + +func ExampleUpsertStreamedWithOnUpsert() { + var ( + testEntites = []User{ + {Id: 1, Name: "test1", Age: 10, Email: "test1@test.com"}, + {Id: 2, Name: "test2", Age: 20, Email: "test2@test.com"}, + } + callback = func(ctx context.Context, affectedRows []any) (err error) { + fmt.Printf("number of affected rows: %d\n", len(affectedRows)) + return nil + } + testSelect = &[]User{} + g, ctx = errgroup.WithContext(context.Background()) + entities = make(chan User, len(testEntites)) + logs = getTestLogging() + db = getTestDb(logs) + log = logs.GetLogger() + err error + ) + initTestDb(db) + + g.Go(func() error { + return UpsertStreamed(ctx, db, entities, WithOnUpsert(callback)) + }) + + for _, entity := range testEntites { + entities <- entity + } + + time.Sleep(1 * time.Second) + close(entities) + + if err = db.Select(testSelect, "SELECT * FROM user"); err != nil { + log.Fatalf("cannot select from db: %v", err) + } + + fmt.Println(*testSelect) + + if err = g.Wait(); err != nil { + log.Fatalf("error while upserting entities: %v", err) + } + + _ = db.Close() + + // Output: + // number of affected rows: 2 + // [{1 test1 10 test1@test.com} {2 test2 20 test2@test.com}] +} + +func ExampleNamedBulkUpsert() { + var ( + testEntites = []User{ + {Id: 1, Name: "test1", Age: 10, Email: "test1@test.com"}, + {Id: 2, Name: "test2", Age: 20, Email: "test2@test.com"}, + } + testSelect = &[]User{} + g, ctx = errgroup.WithContext(context.Background()) + entities = make(chan Entity, len(testEntites)) + logs = getTestLogging() + db = getTestDb(logs) + log = logs.GetLogger() + sem = db.GetSemaphoreForTable(TableName(User{})) + err error + ) + initTestDb(db) + + stmt, placeholders, err := db.QueryBuilder().UpsertStatement(NewUpsertStatement(User{})) + if err != nil { + log.Fatalf("error while building upsert statement: %v", err) + } + + g.Go(func() error { + return db.NamedBulkExec(ctx, stmt, placeholders, sem, entities, com.NeverSplit) + }) + + for _, entity := range testEntites { + entities <- entity + } + + time.Sleep(1 * time.Second) + close(entities) + + if err = db.Select(testSelect, "SELECT * FROM user"); err != nil { + log.Fatalf("cannot select from db: %v", err) + } + + fmt.Println(*testSelect) + + if err = g.Wait(); err != nil { + log.Fatalf("error while upserting entities: %v", err) + } + + _ = db.Close() + + // Output: + // [{1 test1 10 test1@test.com} {2 test2 20 test2@test.com}] +} + +func ExampleNamedBulkUpsertWithOnUpsert() { + var ( + testEntites = []User{ + {Id: 1, Name: "test1", Age: 10, Email: "test1@test.com"}, + {Id: 2, Name: "test2", Age: 20, Email: "test2@test.com"}, + } + testSelect = &[]User{} + callback = func(ctx context.Context, affectedRows []Entity) (err error) { + fmt.Printf("number of affected rows: %d\n", len(affectedRows)) + return nil + } + g, ctx = errgroup.WithContext(context.Background()) + entities = make(chan Entity, len(testEntites)) + logs = getTestLogging() + db = getTestDb(logs) + log = logs.GetLogger() + sem = db.GetSemaphoreForTable(TableName(User{})) + err error + ) + initTestDb(db) + + stmt, placeholders, err := db.QueryBuilder().UpsertStatement(NewUpsertStatement(User{})) + if err != nil { + log.Fatalf("error while building upsert statement: %v", err) + } + + g.Go(func() error { + return db.NamedBulkExec(ctx, stmt, placeholders, sem, entities, com.NeverSplit, callback) + }) + + for _, entity := range testEntites { + entities <- entity + } + + time.Sleep(1 * time.Second) + close(entities) + + if err = db.Select(testSelect, "SELECT * FROM user"); err != nil { + log.Fatalf("cannot select from db: %v", err) + } + + fmt.Println(*testSelect) + + if err = g.Wait(); err != nil { + log.Fatalf("error while upserting entities: %v", err) + } + + _ = db.Close() + + // Output: + // number of affected rows: 2 + // [{1 test1 10 test1@test.com} {2 test2 20 test2@test.com}] +} + +func ExampleNamedExecUpsert() { + var ( + testEntites = []User{ + {Id: 1, Name: "test1", Age: 10, Email: "test1@test.com"}, + {Id: 2, Name: "test2", Age: 20, Email: "test2@test.com"}, + } + testSelect = &[]User{} + ctx = context.Background() + logs = getTestLogging() + db = getTestDb(logs) + log = logs.GetLogger() + err error + ) + initTestDb(db) + + stmt, _, err := db.QueryBuilder().UpsertStatement(NewUpsertStatement(User{})) + if err != nil { + log.Fatalf("error while building upsert statement: %v", err) + } + + for _, entity := range testEntites { + if _, err = db.NamedExecContext(ctx, stmt, entity); err != nil { + log.Fatalf("error while upserting entity: %v", err) + } + } + + if err = db.Select(testSelect, "SELECT * FROM user"); err != nil { + log.Fatalf("cannot select from db: %v", err) + } + + fmt.Println(*testSelect) + + _ = db.Close() + + // Output: + // [{1 test1 10 test1@test.com} {2 test2 20 test2@test.com}] +} + +func ExampleExecUpsert() { + var ( + testEntites = [][]any{ + {1, "test1", 10, "test1@test.com"}, + {2, "test2", 20, "test2@test.com"}, + } + testSelect = &[]User{} + stmt = `INSERT INTO user ("id", "name", "age", "email") VALUES (?, ?, ?, ?) ON CONFLICT DO UPDATE SET "name" = EXCLUDED."name", "age" = EXCLUDED."age", "email" = EXCLUDED."email"` + ctx = context.Background() + logs = getTestLogging() + db = getTestDb(logs) + log = logs.GetLogger() + err error + ) + initTestDb(db) + + //stmt, _, err := db.QueryBuilder().UpsertStatement(NewUpsertStatement(User{})) + //if err != nil { + // log.Fatalf("error while building upsert statement: %v", err) + //} + + for _, entity := range testEntites { + if _, err = db.ExecContext(ctx, stmt, entity...); err != nil { + log.Fatalf("error while upserting entity: %v", err) + } + } + + if err = db.Select(testSelect, "SELECT * FROM user"); err != nil { + log.Fatalf("cannot select from db: %v", err) + } + + fmt.Println(*testSelect) + + _ = db.Close() + + // Output: + // [{1 test1 10 test1@test.com} {2 test2 20 test2@test.com}] +} diff --git a/database/insert.go b/database/insert.go new file mode 100644 index 00000000..2df546e2 --- /dev/null +++ b/database/insert.go @@ -0,0 +1,204 @@ +package database + +import "context" + +// InsertStatement is the interface for building INSERT statements. +type InsertStatement interface { + // Into sets the table name for the INSERT statement. + // Overrides the table name provided by the entity. + Into(table string) InsertStatement + + // SetColumns sets the columns to be inserted. + SetColumns(columns ...string) InsertStatement + + // SetExcludedColumns sets the columns to be excluded from the INSERT statement. + // Excludes also columns set by SetColumns. + SetExcludedColumns(columns ...string) InsertStatement + + // Entity returns the entity associated with the INSERT statement. + Entity() Entity + + // Table returns the table name for the INSERT statement. + Table() string + + // Columns returns the columns to be inserted. + Columns() []string + + // ExcludedColumns returns the columns to be excluded from the INSERT statement. + ExcludedColumns() []string +} + +// NewInsertStatement returns a new insertStatement for the given entity. +func NewInsertStatement(entity Entity) InsertStatement { + return &insertStatement{ + entity: entity, + } +} + +// insertStatement is the default implementation of the InsertStatement interface. +type insertStatement struct { + entity Entity + table string + columns []string + excludedColumns []string +} + +func (i *insertStatement) Into(table string) InsertStatement { + i.table = table + + return i +} + +func (i *insertStatement) SetColumns(columns ...string) InsertStatement { + i.columns = columns + + return i +} + +func (i *insertStatement) SetExcludedColumns(columns ...string) InsertStatement { + i.excludedColumns = columns + + return i +} + +func (i *insertStatement) Entity() Entity { + return i.entity +} + +func (i *insertStatement) Table() string { + return i.table +} + +func (i *insertStatement) Columns() []string { + return i.columns +} + +func (i *insertStatement) ExcludedColumns() []string { + return i.excludedColumns +} + +// InsertSelectStatement is the interface for building INSERT SELECT statements. +type InsertSelectStatement interface { + // Into sets the table name for the INSERT SELECT statement. + // Overrides the table name provided by the entity. + Into(table string) InsertSelectStatement + + // SetColumns sets the columns to be inserted. + SetColumns(columns ...string) InsertSelectStatement + + // SetExcludedColumns sets the columns to be excluded from the INSERT SELECT statement. + // Excludes also columns set by SetColumns. + SetExcludedColumns(columns ...string) InsertSelectStatement + + // SetSelect sets the SELECT statement for the INSERT SELECT statement. + SetSelect(stmt SelectStatement) InsertSelectStatement + + // Entity returns the entity associated with the INSERT SELECT statement. + Entity() Entity + + // Table returns the table name for the INSERT SELECT statement. + Table() string + + // Columns returns the columns to be inserted. + Columns() []string + + // ExcludedColumns returns the columns to be excluded from the INSERT statement. + ExcludedColumns() []string + + // Select returns the SELECT statement for the INSERT SELECT statement. + Select() SelectStatement +} + +// NewInsertSelectStatement returns a new insertSelectStatement for the given entity. +func NewInsertSelectStatement(entity Entity) InsertSelectStatement { + return &insertSelectStatement{ + entity: entity, + } +} + +// insertSelectStatement is the default implementation of the InsertSelectStatement interface. +type insertSelectStatement struct { + entity Entity + table string + columns []string + excludedColumns []string + selectStmt SelectStatement +} + +func (i *insertSelectStatement) Into(table string) InsertSelectStatement { + i.table = table + + return i +} + +func (i *insertSelectStatement) SetColumns(columns ...string) InsertSelectStatement { + i.columns = columns + + return i +} + +func (i *insertSelectStatement) SetExcludedColumns(columns ...string) InsertSelectStatement { + i.excludedColumns = columns + + return i +} + +func (i *insertSelectStatement) SetSelect(stmt SelectStatement) InsertSelectStatement { + i.selectStmt = stmt + + return i +} + +func (i *insertSelectStatement) Entity() Entity { + return i.entity +} + +func (i *insertSelectStatement) Table() string { + return i.table +} + +func (i *insertSelectStatement) Columns() []string { + return i.columns +} + +func (i *insertSelectStatement) ExcludedColumns() []string { + return i.excludedColumns +} + +func (i *insertSelectStatement) Select() SelectStatement { + return i.selectStmt +} + +// InsertOption is a functional option for InsertStreamed(). +type InsertOption func(opts *insertOptions) + +// WithInsertStatement sets the INSERT statement to be used for inserting entities. +func WithInsertStatement(stmt InsertStatement) InsertOption { + return func(opts *insertOptions) { + opts.stmt = stmt + } +} + +// WithOnInsert sets the onInsert callbacks for a successful INSERT statement. +func WithOnInsert(onInsert ...OnSuccess[any]) InsertOption { + return func(opts *insertOptions) { + opts.onInsert = append(opts.onInsert, onInsert...) + } +} + +// insertOptions stores the options for InsertStreamed. +type insertOptions struct { + stmt InsertStatement + onInsert []OnSuccess[any] +} + +// InsertStreamed inserts entities from the given channel into the database. +func InsertStreamed[T any, V EntityConstraint[T]]( + ctx context.Context, + db *DB, + entities <-chan T, + options ...InsertOption, +) error { + // TODO (jr): implement + return nil +} diff --git a/database/query_builder.go b/database/query_builder.go new file mode 100644 index 00000000..5b3f596e --- /dev/null +++ b/database/query_builder.go @@ -0,0 +1,307 @@ +package database + +import ( + "errors" + "fmt" + "github.com/icinga/icinga-go-library/strcase" + "github.com/jmoiron/sqlx/reflectx" + "slices" + "sort" + "strings" +) + +var ( + ErrUnsupportedDriver = errors.New("unsupported database driver") + ErrMissingStatementPart = errors.New("missing statement part") +) + +type QueryBuilder interface { + UpsertStatement(stmt UpsertStatement) (string, int, error) + + InsertStatement(stmt InsertStatement) string + + InsertIgnoreStatement(stmt InsertStatement) (string, error) + + InsertSelectStatement(stmt InsertSelectStatement) (string, error) + + SelectStatement(stmt SelectStatement) string + + UpdateStatement(stmt UpdateStatement) (string, error) + + UpdateAllStatement(stmt UpdateStatement) (string, error) + + DeleteStatement(stmt DeleteStatement) (string, error) + + DeleteAllStatement(stmt DeleteStatement) (string, error) + + BuildColumns(entity Entity, columns []string, excludedColumns []string) []string +} + +func NewQueryBuilder(driver string) QueryBuilder { + return &queryBuilder{ + dbDriver: driver, + columnMap: NewColumnMap(reflectx.NewMapperFunc("db", strcase.Snake)), + } +} + +func NewTestQueryBuilder(driver string) QueryBuilder { + return &queryBuilder{ + dbDriver: driver, + columnMap: NewColumnMap(reflectx.NewMapperFunc("db", strcase.Snake)), + sort: true, + } +} + +type queryBuilder struct { + dbDriver string + columnMap ColumnMap + + // Indicates whether the generated columns should be sorted in ascending order before generating the + // actual statements. This is intended for unit tests only and shouldn't be necessary for production code. + sort bool +} + +func (qb *queryBuilder) UpsertStatement(stmt UpsertStatement) (string, int, error) { + columns := qb.BuildColumns(stmt.Entity(), stmt.Columns(), stmt.ExcludedColumns()) + into := stmt.Table() + if into == "" { + into = TableName(stmt.Entity()) + } + var setFormat, clause string + switch qb.dbDriver { + case MySQL: + clause = "ON DUPLICATE KEY UPDATE" + setFormat = `"%[1]s" = VALUES("%[1]s")` + case PostgreSQL: + var constraint string + if constrainter, ok := stmt.Entity().(PgsqlOnConflictConstrainter); ok { + constraint = constrainter.PgsqlOnConflictConstraint() + } else { + constraint = "pk_" + into + } + + clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT %s DO UPDATE SET", constraint) + setFormat = `"%[1]s" = EXCLUDED."%[1]s"` + case SQLite: + clause = "ON CONFLICT DO UPDATE SET" + setFormat = `"%[1]s" = EXCLUDED."%[1]s"` + default: + return "", 0, fmt.Errorf("%w: %s", ErrUnsupportedDriver, qb.dbDriver) + } + + set := make([]string, 0, len(columns)) + for _, column := range columns { + set = append(set, fmt.Sprintf(setFormat, column)) + } + + return fmt.Sprintf( + `INSERT INTO "%s" ("%s") VALUES (%s) %s %s`, + into, + strings.Join(columns, `", "`), + fmt.Sprintf(":%s", strings.Join(columns, ", :")), + clause, + strings.Join(set, ", "), + ), len(columns), nil +} + +func (qb *queryBuilder) InsertStatement(stmt InsertStatement) string { + columns := qb.BuildColumns(stmt.Entity(), stmt.Columns(), stmt.ExcludedColumns()) + into := stmt.Table() + if into == "" { + into = TableName(stmt.Entity()) + } + + return fmt.Sprintf( + `INSERT INTO "%s" ("%s") VALUES (%s)`, + into, + strings.Join(columns, `", "`), + fmt.Sprintf(":%s", strings.Join(columns, ", :")), + ) +} + +func (qb *queryBuilder) InsertIgnoreStatement(stmt InsertStatement) (string, error) { + columns := qb.BuildColumns(stmt.Entity(), stmt.Columns(), stmt.ExcludedColumns()) + into := stmt.Table() + if into == "" { + into = TableName(stmt.Entity()) + } + + switch qb.dbDriver { + case MySQL: + return fmt.Sprintf( + `INSERT IGNORE INTO "%s" ("%s") VALUES (%s)`, + into, + strings.Join(columns, `", "`), + fmt.Sprintf(":%s", strings.Join(columns, ", :")), + ), nil + case PostgreSQL, SQLite: + return fmt.Sprintf( + `INSERT INTO "%s" ("%s") VALUES (%s) ON CONFLICT DO NOTHING`, + into, + strings.Join(columns, `", "`), + fmt.Sprintf(":%s", strings.Join(columns, ", :")), + ), nil + default: + return "", fmt.Errorf("%w: %s", ErrUnsupportedDriver, qb.dbDriver) + } +} + +func (qb *queryBuilder) InsertSelectStatement(stmt InsertSelectStatement) (string, error) { + columns := qb.BuildColumns(stmt.Entity(), stmt.Columns(), stmt.ExcludedColumns()) + + sel := stmt.Select() + if sel == nil { + return "", fmt.Errorf("%w: %s", ErrMissingStatementPart, "select statement") + } + selectStmt := qb.SelectStatement(sel) + + into := stmt.Table() + if into == "" { + into = TableName(stmt.Entity()) + } + + return fmt.Sprintf( + `INSERT INTO "%s" ("%s") %s`, + into, + strings.Join(columns, `", "`), + selectStmt, + ), nil +} + +func (qb *queryBuilder) SelectStatement(stmt SelectStatement) string { + columns := qb.BuildColumns(stmt.Entity(), stmt.Columns(), stmt.ExcludedColumns()) + + from := stmt.Table() + if from == "" { + from = TableName(stmt.Entity()) + } + + where := stmt.Where() + if where != "" { + where = fmt.Sprintf(" WHERE %s", where) + } + + return fmt.Sprintf( + `SELECT "%s" FROM "%s"%s`, + strings.Join(columns, `", "`), + from, + where, + ) +} + +func (qb *queryBuilder) UpdateStatement(stmt UpdateStatement) (string, error) { + columns := qb.BuildColumns(stmt.Entity(), stmt.Columns(), stmt.ExcludedColumns()) + + table := stmt.Table() + if table == "" { + table = TableName(stmt.Entity()) + } + + where := stmt.Where() + if where == "" { + return "", fmt.Errorf("%w: %s", ErrMissingStatementPart, "where statement - use UpdateAllStatement() instead") + } + + var set []string + + for _, col := range columns { + set = append(set, fmt.Sprintf(`"%[1]s" = :%[1]s`, col)) + } + + return fmt.Sprintf( + `UPDATE "%s" SET %s WHERE %s`, + table, + strings.Join(set, ", "), + where, + ), nil +} + +func (qb *queryBuilder) UpdateAllStatement(stmt UpdateStatement) (string, error) { + columns := qb.BuildColumns(stmt.Entity(), stmt.Columns(), stmt.ExcludedColumns()) + + table := stmt.Table() + if table == "" { + table = TableName(stmt.Entity()) + } + + where := stmt.Where() + if where != "" { + return "", errors.New("cannot use UpdateAllStatement() with where statement - use UpdateStatement() instead") + } + + var set []string + + for _, col := range columns { + set = append(set, fmt.Sprintf(`"%[1]s" = :%[1]s`, col)) + } + + return fmt.Sprintf( + `UPDATE "%s" SET %s`, + table, + set, + ), nil +} + +func (qb *queryBuilder) DeleteStatement(stmt DeleteStatement) (string, error) { + from := stmt.Table() + if from == "" { + from = TableName(stmt.Entity()) + } + where := stmt.Where() + if where != "" { + where = fmt.Sprintf(" WHERE %s", where) + } else { + return "", fmt.Errorf("%w: %s", ErrMissingStatementPart, "cannot use DeleteStatement() without where statement - use DeleteAllStatement() instead") + } + + return fmt.Sprintf( + `DELETE FROM "%s"%s`, + from, + where, + ), nil +} + +func (qb *queryBuilder) DeleteAllStatement(stmt DeleteStatement) (string, error) { + from := stmt.Table() + if from == "" { + from = TableName(stmt.Entity()) + } + where := stmt.Where() + if where != "" { + return "", errors.New("cannot use DeleteAllStatement() with where statement - use DeleteStatement() instead") + } + + return fmt.Sprintf( + `DELETE FROM "%s"`, + from, + ), nil +} + +func (qb *queryBuilder) BuildColumns(entity Entity, columns []string, excludedColumns []string) []string { + var entityColumns []string + + if len(columns) > 0 { + entityColumns = columns + } else { + tempColumns := qb.columnMap.Columns(entity) + entityColumns = make([]string, len(tempColumns)) + copy(entityColumns, tempColumns) + } + + if len(excludedColumns) > 0 { + entityColumns = slices.DeleteFunc( + entityColumns, + func(column string) bool { + return slices.Contains(excludedColumns, column) + }, + ) + } + + if qb.sort { + // The order in which the columns appear is not guaranteed as we extract the columns dynamically + // from the struct. So, we've to sort them here to be able to test the generated statements. + sort.Strings(entityColumns) + } + + return entityColumns[:len(entityColumns):len(entityColumns)] +} diff --git a/database/query_builder_test.go b/database/query_builder_test.go new file mode 100644 index 00000000..b7c92980 --- /dev/null +++ b/database/query_builder_test.go @@ -0,0 +1,638 @@ +package database + +import ( + "github.com/icinga/icinga-go-library/testutils" + "testing" +) + +type InsertStatementTestData struct { + Table string + Columns []string + ExcludedColumns []string +} + +type InsertIgnoreStatementTestData struct { + Driver string + Table string + Columns []string + ExcludedColumns []string +} + +type InsertSelectStatementTestData struct { + Table string + Columns []string + ExcludedColumns []string + Select SelectStatement +} + +type UpdateStatementTestData struct { + Table string + Columns []string + ExcludedColumns []string + Where string +} + +type UpsertStatementTestData struct { + Driver string + Table string + Columns []string + ExcludedColumns []string +} + +type DeleteStatementTestData struct { + Table string + Where string +} + +type DeleteAllStatementTestData struct { + Table string +} + +type SelectStatementTestData struct { + Table string + Columns []string + ExcludedColumns []string + Where string +} + +func TestInsertStatement(t *testing.T) { + tests := []testutils.TestCase[string, InsertStatementTestData]{ + { + Name: "NoColumnsSet", + Expected: `INSERT INTO "user" ("age", "email", "id", "name") VALUES (:age, :email, :id, :name)`, + }, + { + Name: "ColumnsSet", + Expected: `INSERT INTO "user" ("email", "id", "name") VALUES (:email, :id, :name)`, + Data: InsertStatementTestData{ + Columns: []string{"id", "name", "email"}, + }, + }, + { + Name: "ExcludedColumnsSet", + Expected: `INSERT INTO "user" ("age", "id", "name") VALUES (:age, :id, :name)`, + Data: InsertStatementTestData{ + ExcludedColumns: []string{"email"}, + }, + }, + { + Name: "ColumnsAndExcludedColumnsSet", + Expected: `INSERT INTO "user" ("id", "name") VALUES (:id, :name)`, + Data: InsertStatementTestData{ + Columns: []string{"id", "name", "email"}, + ExcludedColumns: []string{"email"}, + }, + }, + { + Name: "OverrideTableName", + Expected: `INSERT INTO "custom_table_name" ("email", "id", "name") VALUES (:email, :id, :name)`, + Data: InsertStatementTestData{ + Table: "custom_table_name", + Columns: []string{"id", "name", "email"}, + }, + }, + } + + for _, tst := range tests { + t.Run(tst.Name, tst.F(func(data InsertStatementTestData) (string, error) { + var actual string + var err error + + stmt := NewInsertStatement(&User{}). + SetColumns(data.Columns...). + SetExcludedColumns(data.ExcludedColumns...) + + if data.Table != "" { + stmt.Into(data.Table) + } + + qb := NewTestQueryBuilder(MySQL) + actual = qb.InsertStatement(stmt) + + return actual, err + + })) + } +} + +func TestInsertIgnoreStatement(t *testing.T) { + tests := []testutils.TestCase[string, InsertIgnoreStatementTestData]{ + { + Name: "NoColumnsSet_MySQL", + Expected: `INSERT IGNORE INTO "user" ("age", "email", "id", "name") VALUES (:age, :email, :id, :name)`, + Data: InsertIgnoreStatementTestData{ + Driver: MySQL, + }, + }, + { + Name: "ColumnsSet_MySQL", + Expected: `INSERT IGNORE INTO "user" ("email", "id", "name") VALUES (:email, :id, :name)`, + Data: InsertIgnoreStatementTestData{ + Driver: MySQL, + Columns: []string{"id", "name", "email"}, + }, + }, + { + Name: "ExcludedColumnsSet_MySQL", + Expected: `INSERT IGNORE INTO "user" ("age", "id", "name") VALUES (:age, :id, :name)`, + Data: InsertIgnoreStatementTestData{ + Driver: MySQL, + ExcludedColumns: []string{"email"}, + }, + }, + { + Name: "ColumnsAndExcludedColumnsSet_MySQL", + Expected: `INSERT IGNORE INTO "user" ("id", "name") VALUES (:id, :name)`, + Data: InsertIgnoreStatementTestData{ + Driver: MySQL, + Columns: []string{"id", "name", "email"}, + ExcludedColumns: []string{"email"}, + }, + }, + { + Name: "OverrideTableName_MySQL", + Expected: `INSERT IGNORE INTO "custom_table_name" ("email", "id", "name") VALUES (:email, :id, :name)`, + Data: InsertIgnoreStatementTestData{ + Driver: MySQL, + Table: "custom_table_name", + Columns: []string{"id", "name", "email"}, + }, + }, + { + Name: "NoColumnsSet_PostgreSQL", + Expected: `INSERT INTO "user" ("age", "email", "id", "name") VALUES (:age, :email, :id, :name) ON CONFLICT DO NOTHING`, + Data: InsertIgnoreStatementTestData{ + Driver: PostgreSQL, + }, + }, + { + Name: "ColumnsSet_PostgreSQL", + Expected: `INSERT INTO "user" ("email", "id", "name") VALUES (:email, :id, :name) ON CONFLICT DO NOTHING`, + Data: InsertIgnoreStatementTestData{ + Driver: PostgreSQL, + Columns: []string{"id", "name", "email"}, + }, + }, + { + Name: "ExcludedColumnsSet_PostgreSQL", + Expected: `INSERT INTO "user" ("age", "id", "name") VALUES (:age, :id, :name) ON CONFLICT DO NOTHING`, + Data: InsertIgnoreStatementTestData{ + Driver: PostgreSQL, + ExcludedColumns: []string{"email"}, + }, + }, + { + Name: "ColumnsAndExcludedColumnsSet_PostgreSQL", + Expected: `INSERT INTO "user" ("id", "name") VALUES (:id, :name) ON CONFLICT DO NOTHING`, + Data: InsertIgnoreStatementTestData{ + Driver: PostgreSQL, + Columns: []string{"id", "name", "email"}, + ExcludedColumns: []string{"email"}, + }, + }, + { + Name: "OverrideTableName_PostgreSQL", + Expected: `INSERT INTO "custom_table_name" ("email", "id", "name") VALUES (:email, :id, :name) ON CONFLICT DO NOTHING`, + Data: InsertIgnoreStatementTestData{ + Driver: PostgreSQL, + Table: "custom_table_name", + Columns: []string{"id", "name", "email"}, + ExcludedColumns: nil, + }, + }, + { + Name: "UnsupportedDriver", + Error: testutils.ErrorIs(ErrUnsupportedDriver), + Data: InsertIgnoreStatementTestData{ + Driver: "abcxyz", // Unsupported driver + Columns: []string{"id", "name", "email"}, + ExcludedColumns: nil, + }, + }, + } + + for _, tst := range tests { + t.Run(tst.Name, tst.F(func(data InsertIgnoreStatementTestData) (string, error) { + var actual string + var err error + + stmt := NewInsertStatement(&User{}). + SetColumns(data.Columns...). + SetExcludedColumns(data.ExcludedColumns...) + + if data.Table != "" { + stmt.Into(data.Table) + } + + qb := NewTestQueryBuilder(data.Driver) + actual, err = qb.InsertIgnoreStatement(stmt) + + return actual, err + + })) + } +} + +func TestInsertSelectStatement(t *testing.T) { + tests := []testutils.TestCase[string, InsertSelectStatementTestData]{ + { + Name: "ColumnsSet", + Expected: `INSERT INTO "user" ("email", "id", "name") SELECT "email", "id", "name" FROM "user" WHERE id = :id`, + Data: InsertSelectStatementTestData{ + Columns: []string{"id", "name", "email"}, + Select: NewSelectStatement(&User{}).SetColumns("id", "name", "email").SetWhere("id = :id"), + }, + }, + { + Name: "ExcludedColumnsSet", + Expected: `INSERT INTO "user" ("age", "id", "name") SELECT "age", "id", "name" FROM "user" WHERE id = :id`, + Data: InsertSelectStatementTestData{ + ExcludedColumns: []string{"email"}, + Select: NewSelectStatement(&User{}).SetExcludedColumns("email").SetWhere("id = :id"), + }, + }, + { + Name: "ColumnsAndExcludedColumnsSet", + Expected: `INSERT INTO "user" ("id", "name") SELECT "id", "name" FROM "user" WHERE id = :id`, + Data: InsertSelectStatementTestData{ + Columns: []string{"id", "name", "email"}, + ExcludedColumns: []string{"email"}, + Select: NewSelectStatement(&User{}).SetColumns("id", "name", "email").SetExcludedColumns("email").SetWhere("id = :id"), + }, + }, + { + Name: "OverrideTableName", + Expected: `INSERT INTO "custom_table_name" ("email", "id", "name") SELECT "email", "id", "name" FROM "user" WHERE id = :id`, + Data: InsertSelectStatementTestData{ + Table: "custom_table_name", + Columns: []string{"id", "name", "email"}, + Select: NewSelectStatement(&User{}).SetColumns("id", "name", "email").SetWhere("id = :id"), + }, + }, + { + Name: "SelectStatementMissing", + Error: testutils.ErrorIs(ErrMissingStatementPart), + Data: InsertSelectStatementTestData{}, + }, + //{ + // Name: "InvalidColumnName", + // Data: InsertStatementTestData{ + // Columns: []string{"id", "name", "email", "invalid_column"}, + // ExcludedColumns: nil, + // }, + // Error: testutils.ErrorIs(ErrInvalidColumnName), + //}, + } + + for _, tst := range tests { + t.Run(tst.Name, tst.F(func(data InsertSelectStatementTestData) (string, error) { + var actual string + var err error + + stmt := NewInsertSelectStatement(&User{}). + SetColumns(data.Columns...). + SetExcludedColumns(data.ExcludedColumns...) + + if data.Select != nil { + stmt.SetSelect(data.Select.(SelectStatement)) + } + + if data.Table != "" { + stmt.Into(data.Table) + } + + qb := NewTestQueryBuilder(MySQL) + actual, err = qb.InsertSelectStatement(stmt) + + return actual, err + + })) + } +} + +func TestUpdateStatement(t *testing.T) { + tests := []testutils.TestCase[string, UpdateStatementTestData]{ + { + Name: "NoWhereSet", + Error: testutils.ErrorIs(ErrMissingStatementPart), + }, + { + Name: "ColumnsSet", + Expected: `UPDATE "user" SET "email" = :email, "name" = :name WHERE id = :id`, + Data: UpdateStatementTestData{ + Columns: []string{"name", "email"}, + Where: "id = :id", + }, + }, + { + Name: "ExcludedColumnsSet", + Expected: `UPDATE "user" SET "email" = :email, "name" = :name WHERE id = :id`, + Data: UpdateStatementTestData{ + ExcludedColumns: []string{"id", "age"}, + Where: "id = :id", + }, + }, + { + Name: "OverrideTableName", + Expected: `UPDATE "custom_table_name" SET "email" = :email, "id" = :id, "name" = :name WHERE id = :id`, + Data: UpdateStatementTestData{ + Table: "custom_table_name", + Columns: []string{"id", "name", "email"}, + Where: "id = :id", + }, + }, + } + + for _, tst := range tests { + t.Run(tst.Name, tst.F(func(data UpdateStatementTestData) (string, error) { + var actual string + var err error + + stmt := NewUpdateStatement(&User{}). + SetColumns(data.Columns...). + SetExcludedColumns(data.ExcludedColumns...) + + if data.Where != "" { + stmt.SetWhere(data.Where) + } + + if data.Table != "" { + stmt.SetTable(data.Table) + } + + qb := NewTestQueryBuilder(MySQL) + actual, err = qb.UpdateStatement(stmt) + + return actual, err + + })) + } +} + +func TestUpsertStatement(t *testing.T) { + tests := []testutils.TestCase[string, UpsertStatementTestData]{ + { + Name: "NoColumnsSet_MySQL", + Expected: `INSERT INTO "user" ("age", "email", "id", "name") VALUES (:age, :email, :id, :name) ON DUPLICATE KEY UPDATE "age" = VALUES("age"), "email" = VALUES("email"), "id" = VALUES("id"), "name" = VALUES("name")`, + Data: UpsertStatementTestData{ + Driver: MySQL, + }, + }, + { + Name: "ColumnsSet_MySQL", + Expected: `INSERT INTO "user" ("email", "id", "name") VALUES (:email, :id, :name) ON DUPLICATE KEY UPDATE "email" = VALUES("email"), "id" = VALUES("id"), "name" = VALUES("name")`, + Data: UpsertStatementTestData{ + Driver: MySQL, + Columns: []string{"id", "name", "email"}, + }, + }, + { + Name: "ExcludedColumnsSet_MySQL", + Expected: `INSERT INTO "user" ("age", "id", "name") VALUES (:age, :id, :name) ON DUPLICATE KEY UPDATE "age" = VALUES("age"), "id" = VALUES("id"), "name" = VALUES("name")`, + Data: UpsertStatementTestData{ + Driver: MySQL, + ExcludedColumns: []string{"email"}, + }, + }, + { + Name: "ColumnsAndExcludedColumnsSet_MySQL", + Expected: `INSERT INTO "user" ("id", "name") VALUES (:id, :name) ON DUPLICATE KEY UPDATE "id" = VALUES("id"), "name" = VALUES("name")`, + Data: UpsertStatementTestData{ + Driver: MySQL, + Columns: []string{"id", "name", "email"}, + ExcludedColumns: []string{"email"}, + }, + }, + { + Name: "OverrideTableName_MySQL", + Expected: `INSERT INTO "custom_table_name" ("email", "id", "name") VALUES (:email, :id, :name) ON DUPLICATE KEY UPDATE "email" = VALUES("email"), "id" = VALUES("id"), "name" = VALUES("name")`, + Data: UpsertStatementTestData{ + Driver: MySQL, + Table: "custom_table_name", + Columns: []string{"id", "name", "email"}, + }, + }, + { + Name: "NoColumnsSet_PostgreSQL", + Expected: `INSERT INTO "user" ("age", "email", "id", "name") VALUES (:age, :email, :id, :name) ON CONFLICT ON CONSTRAINT pk_user DO UPDATE SET "age" = EXCLUDED."age", "email" = EXCLUDED."email", "id" = EXCLUDED."id", "name" = EXCLUDED."name"`, + Data: UpsertStatementTestData{ + Driver: PostgreSQL, + }, + }, + { + Name: "ColumnsSet_PostgreSQL", + Expected: `INSERT INTO "user" ("email", "id", "name") VALUES (:email, :id, :name) ON CONFLICT ON CONSTRAINT pk_user DO UPDATE SET "email" = EXCLUDED."email", "id" = EXCLUDED."id", "name" = EXCLUDED."name"`, + Data: UpsertStatementTestData{ + Driver: PostgreSQL, + Columns: []string{"id", "name", "email"}, + }, + }, + { + Name: "ExcludedColumnsSet_PostgreSQL", + Expected: `INSERT INTO "user" ("age", "id", "name") VALUES (:age, :id, :name) ON CONFLICT ON CONSTRAINT pk_user DO UPDATE SET "age" = EXCLUDED."age", "id" = EXCLUDED."id", "name" = EXCLUDED."name"`, + Data: UpsertStatementTestData{ + Driver: PostgreSQL, + ExcludedColumns: []string{"email"}, + }, + }, + { + Name: "ColumnsAndExcludedColumnsSet_PostgreSQL", + Expected: `INSERT INTO "user" ("id", "name") VALUES (:id, :name) ON CONFLICT ON CONSTRAINT pk_user DO UPDATE SET "id" = EXCLUDED."id", "name" = EXCLUDED."name"`, + Data: UpsertStatementTestData{ + Driver: PostgreSQL, + Columns: []string{"id", "name", "email"}, + ExcludedColumns: []string{"email"}, + }, + }, + { + Name: "OverrideTableName_PostgreSQL", + Expected: `INSERT INTO "custom_table_name" ("email", "id", "name") VALUES (:email, :id, :name) ON CONFLICT ON CONSTRAINT pk_custom_table_name DO UPDATE SET "email" = EXCLUDED."email", "id" = EXCLUDED."id", "name" = EXCLUDED."name"`, + Data: UpsertStatementTestData{ + Driver: PostgreSQL, + Table: "custom_table_name", + Columns: []string{"id", "name", "email"}, + }, + }, + } + + for _, tst := range tests { + t.Run(tst.Name, tst.F(func(data UpsertStatementTestData) (string, error) { + var actual string + var err error + + stmt := NewUpsertStatement(&User{}). + SetColumns(data.Columns...). + SetExcludedColumns(data.ExcludedColumns...) + + if data.Table != "" { + stmt.Into(data.Table) + } + + qb := NewTestQueryBuilder(data.Driver) + actual, _, err = qb.UpsertStatement(stmt) + + return actual, err + + })) + } +} + +func TestDeleteStatement(t *testing.T) { + tests := []testutils.TestCase[string, DeleteStatementTestData]{ + { + Name: "NoWhereSet", + Error: testutils.ErrorIs(ErrMissingStatementPart), + }, + { + Name: "WhereSet", + Expected: `DELETE FROM "user" WHERE id = :id`, + Data: DeleteStatementTestData{ + Where: "id = :id", + }, + }, + { + Name: "OverrideTableName", + Expected: `DELETE FROM "custom_table_name" WHERE id = :id`, + Data: DeleteStatementTestData{ + Table: "custom_table_name", + Where: "id = :id", + }, + }, + } + + for _, tst := range tests { + t.Run(tst.Name, tst.F(func(data DeleteStatementTestData) (string, error) { + var actual string + var err error + + stmt := NewDeleteStatement(&User{}) + + if data.Where != "" { + stmt.SetWhere(data.Where) + } + + if data.Table != "" { + stmt.From(data.Table) + } + + qb := NewTestQueryBuilder(MySQL) + actual, err = qb.DeleteStatement(stmt) + + return actual, err + + })) + } +} + +func TestDeleteAllStatement(t *testing.T) { + tests := []testutils.TestCase[string, DeleteAllStatementTestData]{ + { + Name: "AutoTableName", + Expected: `DELETE FROM "user"`, + }, + { + Name: "OverrideTableName", + Expected: `DELETE FROM "custom_table_name"`, + Data: DeleteAllStatementTestData{ + Table: "custom_table_name", + }, + }, + } + + for _, tst := range tests { + t.Run(tst.Name, tst.F(func(data DeleteAllStatementTestData) (string, error) { + var actual string + var err error + + stmt := NewDeleteStatement(&User{}) + + if data.Table != "" { + stmt.From(data.Table) + } + + qb := NewTestQueryBuilder(MySQL) + actual, err = qb.DeleteAllStatement(stmt) + + return actual, err + + })) + } +} + +func TestSelectStatement(t *testing.T) { + tests := []testutils.TestCase[string, SelectStatementTestData]{ + { + Name: "NoColumnsSet", + Expected: `SELECT "age", "email", "id", "name" FROM "user"`, + }, + { + Name: "ColumnsSet", + Expected: `SELECT "email", "id", "name" FROM "user"`, + Data: SelectStatementTestData{ + Columns: []string{"id", "name", "email"}, + }, + }, + { + Name: "ExcludedColumnsSet", + Expected: `SELECT "age", "id", "name" FROM "user"`, + Data: SelectStatementTestData{ + ExcludedColumns: []string{"email"}, + }, + }, + { + Name: "ColumnsAndExcludedColumnsSet", + Expected: `SELECT "id", "name" FROM "user"`, + Data: SelectStatementTestData{ + Columns: []string{"id", "name", "email"}, + ExcludedColumns: []string{"email"}, + }, + }, + { + Name: "OverrideTableName", + Expected: `SELECT "email", "id", "name" FROM "custom_table_name"`, + Data: SelectStatementTestData{ + Table: "custom_table_name", + Columns: []string{"id", "name", "email"}, + }, + }, + { + Name: "WhereSet", + Expected: `SELECT "age", "email", "id", "name" FROM "user" WHERE id = :id`, + Data: SelectStatementTestData{ + Where: "id = :id", + }, + }, + { + Name: "MultipleConditionsWhereSet", + Expected: `SELECT "age", "email", "id", "name" FROM "user" WHERE id = :id AND name = :name AND email = :email`, + Data: SelectStatementTestData{ + Where: "id = :id AND name = :name AND email = :email", + }, + }, + } + + for _, tst := range tests { + t.Run(tst.Name, tst.F(func(data SelectStatementTestData) (string, error) { + var actual string + var err error + + stmt := NewSelectStatement(&User{}). + SetColumns(data.Columns...). + SetExcludedColumns(data.ExcludedColumns...) + + if data.Table != "" { + stmt.From(data.Table) + } + + if data.Where != "" { + stmt.SetWhere(data.Where) + } + + qb := NewTestQueryBuilder(MySQL) + actual = qb.SelectStatement(stmt) + + return actual, err + + })) + } +} diff --git a/database/select.go b/database/select.go new file mode 100644 index 00000000..9ced365e --- /dev/null +++ b/database/select.go @@ -0,0 +1,93 @@ +package database + +// SelectStatement is the interface for building SELECT statements. +type SelectStatement interface { + // From sets the table name for the SELECT statement. + // Overrides the table name provided by the entity. + From(table string) SelectStatement + + // SetColumns sets the columns to be selected. + SetColumns(columns ...string) SelectStatement + + // SetExcludedColumns sets the columns to be excluded from the SELECT statement. + // Excludes also columns set by SetColumns. + SetExcludedColumns(columns ...string) SelectStatement + + // SetWhere sets the where clause for the SELECT statement. + SetWhere(where string) SelectStatement + + // Entity returns the entity associated with the SELECT statement. + Entity() Entity + + // Table returns the table name for the SELECT statement. + Table() string + + // Columns returns the columns to be selected. + Columns() []string + + // ExcludedColumns returns the columns to be excluded from the SELECT statement. + ExcludedColumns() []string + + // Where returns the where clause for the SELECT statement. + Where() string +} + +// NewSelectStatement returns a new selectStatement for the given entity. +func NewSelectStatement(entity Entity) SelectStatement { + return &selectStatement{ + entity: entity, + } +} + +// selectStatement is the default implementation of the SelectStatement interface. +type selectStatement struct { + entity Entity + table string + columns []string + excludedColumns []string + where string +} + +func (s *selectStatement) From(table string) SelectStatement { + s.table = table + + return s +} + +func (s *selectStatement) SetColumns(columns ...string) SelectStatement { + s.columns = columns + + return s +} + +func (s *selectStatement) SetExcludedColumns(columns ...string) SelectStatement { + s.excludedColumns = columns + + return s +} + +func (s *selectStatement) SetWhere(where string) SelectStatement { + s.where = where + + return s +} + +func (s *selectStatement) Entity() Entity { + return s.entity +} + +func (s *selectStatement) Table() string { + return s.table +} + +func (s *selectStatement) Columns() []string { + return s.columns +} + +func (s *selectStatement) ExcludedColumns() []string { + return s.excludedColumns +} + +func (s *selectStatement) Where() string { + return s.where +} diff --git a/database/testutils.go b/database/testutils.go new file mode 100644 index 00000000..d8bb830b --- /dev/null +++ b/database/testutils.go @@ -0,0 +1,95 @@ +package database + +import ( + "fmt" + "github.com/creasty/defaults" + "github.com/icinga/icinga-go-library/logging" + "github.com/icinga/icinga-go-library/utils" + "go.uber.org/zap/zapcore" + "math/rand" + "strconv" + "time" +) + +type User struct { + Id Id + Name string + Age int + Email string +} + +type Id int + +func (i Id) String() string { + return strconv.Itoa(int(i)) +} + +func (m User) ID() ID { + return m.Id +} + +func (m User) SetID(id ID) { + m.Id = id.(Id) +} + +func (m User) Fingerprint() Fingerprinter { + return m +} + +func getTestLogging() *logging.Logging { + logs, err := logging.NewLoggingFromConfig( + "Icinga Go Library", + logging.Config{Level: zapcore.DebugLevel, Output: "console", Interval: time.Second * 10}, + ) + if err != nil { + utils.PrintErrorThenExit(err, 1) + } + + return logs +} + +func getTestDb(logs *logging.Logging) *DB { + var defaultOptions Options + + if err := defaults.Set(&defaultOptions); err != nil { + utils.PrintErrorThenExit(err, 1) + } + + randomName := strconv.Itoa(rand.Int()) + + db, err := NewDbFromConfig( + &Config{Type: "sqlite", Database: fmt.Sprintf(":memory:%s", randomName), Options: defaultOptions}, + logs.GetChildLogger("database"), + RetryConnectorCallbacks{}, + ) + if err != nil { + utils.PrintErrorThenExit(err, 1) + } + + return db +} + +func initTestDb(db *DB) { + if _, err := db.Query("DROP TABLE IF EXISTS user"); err != nil { + utils.PrintErrorThenExit(err, 1) + } + + if _, err := db.Query(`CREATE TABLE user ("id" INTEGER PRIMARY KEY, "name" VARCHAR(255) DEFAULT '', "age" INTEGER DEFAULT 0, "email" VARCHAR(255) DEFAULT '')`); err != nil { + utils.PrintErrorThenExit(err, 1) + } +} + +func prefillTestDb(db *DB) { + entities := []User{ + {Id: 1, Name: "Alice Johnson", Age: 25, Email: "alice.johnson@example.com"}, + {Id: 2, Name: "Bob Smith", Age: 30, Email: "bob.smith@example.com"}, + {Id: 3, Name: "Charlie Brown", Age: 22, Email: "charlie.brown@example.com"}, + {Id: 4, Name: "Diana Prince", Age: 28, Email: "diana.prince@example.com"}, + } + + for _, entity := range entities { + if _, err := db.NamedExec(`INSERT INTO user ("id", "name", "age", "email") VALUES (:id, :name, :age, :email)`, entity); err != nil { + utils.PrintErrorThenExit(err, 1) + } + } +} diff --git a/database/update.go b/database/update.go new file mode 100644 index 00000000..d6a6946f --- /dev/null +++ b/database/update.go @@ -0,0 +1,128 @@ +package database + +import "context" + +// UpdateStatement is the interface for building UPDATE statements. +type UpdateStatement interface { + // SetTable sets the table name for the UPDATE statement. + // Overrides the table name provided by the entity. + SetTable(table string) UpdateStatement + + // SetColumns sets the columns to be updated. + SetColumns(columns ...string) UpdateStatement + + // SetExcludedColumns sets the columns to be excluded from the UPDATE statement. + // Excludes also columns set by SetColumns. + SetExcludedColumns(columns ...string) UpdateStatement + + // SetWhere sets the where clause for the UPDATE statement. + SetWhere(where string) UpdateStatement + + // Entity returns the entity associated with the UPDATE statement. + Entity() Entity + + // Table returns the table name for the UPDATE statement. + Table() string + + // Columns returns the columns to be updated. + Columns() []string + + // ExcludedColumns returns the columns to be excluded from the UPDATE statement. + ExcludedColumns() []string + + // Where returns the where clause for the UPDATE statement. + Where() string +} + +// NewUpdateStatement returns a new updateStatement for the given entity. +func NewUpdateStatement(entity Entity) UpdateStatement { + return &updateStatement{ + entity: entity, + } +} + +// updateStatement is the default implementation of the UpdateStatement interface. +type updateStatement struct { + entity Entity + table string + columns []string + excludedColumns []string + where string +} + +func (u *updateStatement) SetTable(table string) UpdateStatement { + u.table = table + + return u +} + +func (u *updateStatement) SetColumns(columns ...string) UpdateStatement { + u.columns = columns + + return u +} + +func (u *updateStatement) SetExcludedColumns(columns ...string) UpdateStatement { + u.excludedColumns = columns + + return u +} + +func (u *updateStatement) SetWhere(where string) UpdateStatement { + u.where = where + + return u +} + +func (u *updateStatement) Entity() Entity { + return u.entity +} + +func (u *updateStatement) Table() string { + return u.table +} + +func (u *updateStatement) Columns() []string { + return u.columns +} + +func (u *updateStatement) ExcludedColumns() []string { + return u.excludedColumns +} + +func (u *updateStatement) Where() string { + return u.where +} + +// UpdateOption is a functional option for UpdateStreamed(). +type UpdateOption func(opts *updateOptions) + +// WithUpdateStatement sets the UPDATE statement to be used for updating entities. +func WithUpdateStatement(stmt UpdateStatement) UpdateOption { + return func(opts *updateOptions) { + opts.stmt = stmt + } +} + +// WithOnUpdate sets the callback functions to be called after a successful UPDATE. +func WithOnUpdate(onUpdate ...OnSuccess[any]) UpdateOption { + return func(opts *updateOptions) { + opts.onUpdate = append(opts.onUpdate, onUpdate...) + } +} + +// updateOptions stores the options for UpdateStreamed. +type updateOptions struct { + stmt UpdateStatement + onUpdate []OnSuccess[any] +} + +func UpdateStreamed[T any, V EntityConstraint[T]]( + ctx context.Context, + db *DB, + entities <-chan T, + options ...UpdateOption, +) error { + // TODO (jr): implement + return nil +} diff --git a/database/upsert.go b/database/upsert.go new file mode 100644 index 00000000..b55f9316 --- /dev/null +++ b/database/upsert.go @@ -0,0 +1,242 @@ +package database + +import ( + "context" + "github.com/icinga/icinga-go-library/backoff" + "github.com/icinga/icinga-go-library/com" + "github.com/icinga/icinga-go-library/retry" + "github.com/pkg/errors" + "golang.org/x/sync/errgroup" + "golang.org/x/sync/semaphore" + "time" +) + +// UpsertStatement is the interface for building UPSERT statements. +type UpsertStatement interface { + // Into sets the table name for the UPSERT statement. + // Overrides the table name provided by the entity. + Into(table string) UpsertStatement + + // SetColumns sets the columns to be inserted or updated. + SetColumns(columns ...string) UpsertStatement + + // SetExcludedColumns sets the columns to be excluded from the UPSERT statement. + // Excludes also columns set by SetColumns. + SetExcludedColumns(columns ...string) UpsertStatement + + // Entity returns the entity associated with the UPSERT statement. + Entity() Entity + + // Table returns the table name for the UPSERT statement. + Table() string + + // Columns returns the columns to be inserted or updated. + Columns() []string + + // ExcludedColumns returns the columns to be excluded from the UPSERT statement. + ExcludedColumns() []string +} + +// NewUpsertStatement returns a new upsertStatement for the given entity. +func NewUpsertStatement(entity Entity) UpsertStatement { + return &upsertStatement{ + entity: entity, + } +} + +// upsertStatement is the default implementation of the UpsertStatement interface. +type upsertStatement struct { + entity Entity + table string + columns []string + excludedColumns []string +} + +func (u *upsertStatement) Into(table string) UpsertStatement { + u.table = table + + return u +} + +func (u *upsertStatement) SetColumns(columns ...string) UpsertStatement { + u.columns = columns + + return u +} + +func (u *upsertStatement) SetExcludedColumns(columns ...string) UpsertStatement { + u.excludedColumns = columns + + return u +} + +func (u *upsertStatement) Entity() Entity { + return u.entity +} + +func (u *upsertStatement) Table() string { + return u.table +} + +func (u *upsertStatement) Columns() []string { + return u.columns +} + +func (u *upsertStatement) ExcludedColumns() []string { + return u.excludedColumns +} + +// UpsertOption is a functional option for UpsertStreamed(). +type UpsertOption func(opts *upsertOptions) + +// WithUpsertStatement sets the UPSERT statement to be used for upserting entities. +func WithUpsertStatement(stmt UpsertStatement) UpsertOption { + return func(opts *upsertOptions) { + opts.stmt = stmt + } +} + +// WithOnUpsert sets the callback functions to be called after a successful UPSERT. +func WithOnUpsert(onUpsert ...OnSuccess[any]) UpsertOption { + return func(opts *upsertOptions) { + opts.onUpsert = append(opts.onUpsert, onUpsert...) + } +} + +// upsertOptions stores the options for UpsertStreamed. +type upsertOptions struct { + stmt UpsertStatement + onUpsert []OnSuccess[any] +} + +// UpsertStreamed upserts entities from the given channel into the database. +func UpsertStreamed[T any, V EntityConstraint[T]]( + ctx context.Context, + db *DB, + entities <-chan T, + options ...UpsertOption, +) error { + var ( + opts = &upsertOptions{} + entityType = V(new(T)) + sem = db.GetSemaphoreForTable(TableName(entityType)) + stmt string + placeholders int + err error + ) + + for _, option := range options { + option(opts) + } + + if opts.stmt != nil { + stmt, placeholders, err = db.QueryBuilder().UpsertStatement(opts.stmt) + if err != nil { + return err + } + } else { + stmt, placeholders, err = db.QueryBuilder().UpsertStatement(NewUpsertStatement(entityType)) + if err != nil { + return err + } + } + + return namedBulkExec[T]( + ctx, db, stmt, db.BatchSizeByPlaceholders(placeholders), sem, + entities, splitOnDupId[T], opts.onUpsert..., + ) +} + +func namedBulkExec[T any]( + ctx context.Context, + db *DB, + query string, + count int, + sem *semaphore.Weighted, + arg <-chan T, + splitPolicyFactory com.BulkChunkSplitPolicyFactory[T], + onSuccess ...OnSuccess[any], +) error { + var counter com.Counter + defer db.Log(ctx, query, &counter).Stop() + + g, ctx := errgroup.WithContext(ctx) + bulk := com.Bulk(ctx, arg, count, splitPolicyFactory) + + g.Go(func() error { + for { + select { + case b, ok := <-bulk: + if !ok { + return nil + } + + if err := sem.Acquire(ctx, 1); err != nil { + return errors.Wrap(err, "can't acquire semaphore") + } + + g.Go(func(b []T) func() error { + return func() error { + defer sem.Release(1) + + return retry.WithBackoff( + ctx, + func(ctx context.Context) error { + _, err := db.NamedExecContext(ctx, query, b) + if err != nil { + return CantPerformQuery(err, query) + } + + counter.Add(uint64(len(b))) + + for _, onSuccess := range onSuccess { + // TODO (jr): remove -> workaround vvvv + var items []any + for _, item := range b { + items = append(items, any(item)) + } + // TODO ---- workaround end ---- ^^^^ + + if err := onSuccess(ctx, items); err != nil { + return err + } + } + + return nil + }, + retry.Retryable, + backoff.NewExponentialWithJitter(1*time.Millisecond, 1*time.Second), + db.GetDefaultRetrySettings(), + ) + } + }(b)) + case <-ctx.Done(): + return ctx.Err() + } + } + }) + + return g.Wait() +} + +func splitOnDupId[T any]() com.BulkChunkSplitPolicy[T] { + seenIds := map[string]struct{}{} + + return func(ider T) bool { + entity, ok := any(ider).(IDer) + if !ok { + panic("Type T does not implement IDer") + } + + id := entity.ID().String() + + _, ok = seenIds[id] + if ok { + seenIds = map[string]struct{}{id: {}} + } else { + seenIds[id] = struct{}{} + } + + return ok + } +} diff --git a/database/upsert_test.go b/database/upsert_test.go new file mode 100644 index 00000000..3dada2b4 --- /dev/null +++ b/database/upsert_test.go @@ -0,0 +1,163 @@ +package database + +import ( + "context" + "github.com/icinga/icinga-go-library/testutils" + "testing" + "time" +) + +type UpsertStreamedTestData struct { + Entities []User + Statement UpsertStatement + Callbacks []OnSuccess[User] +} + +func TestUpsertStreamed(t *testing.T) { + tests := []testutils.TestCase[[]User, UpsertStreamedTestData]{ + { + Name: "Insert", + Expected: []User{ + {Id: 1, Name: "Alice Johnson", Age: 25, Email: "alice.johnson@example.com"}, + {Id: 2, Name: "Bob Smith", Age: 30, Email: "bob.smith@example.com"}, + {Id: 3, Name: "Charlie Brown", Age: 22, Email: "charlie.brown@example.com"}, + {Id: 4, Name: "Diana Prince", Age: 28, Email: "diana.prince@example.com"}, + {Id: 5, Name: "Evan Davis", Age: 35, Email: "evan.davis@example.com"}, + {Id: 6, Name: "Fiona White", Age: 27, Email: "fiona.white@example.com"}, + {Id: 7, Name: "George King", Age: 29, Email: "george.king@example.com"}, + {Id: 8, Name: "Hannah Moore", Age: 31, Email: "hannah.moore@example.com"}, + }, + Data: UpsertStreamedTestData{ + Entities: []User{ + {Id: 5, Name: "Evan Davis", Age: 35, Email: "evan.davis@example.com"}, + {Id: 6, Name: "Fiona White", Age: 27, Email: "fiona.white@example.com"}, + {Id: 7, Name: "George King", Age: 29, Email: "george.king@example.com"}, + {Id: 8, Name: "Hannah Moore", Age: 31, Email: "hannah.moore@example.com"}, + }, + }, + }, + { + Name: "Update", + Expected: []User{ + {Id: 1, Name: "Alice Johnson", Age: 25, Email: "alice.johnson@example.com"}, + {Id: 2, Name: "Bob Smith", Age: 30, Email: "bob.smith@example.com"}, + {Id: 3, Name: "Evan Davis", Age: 35, Email: "evan.davis@example.com"}, + {Id: 4, Name: "Fiona White", Age: 27, Email: "fiona.white@example.com"}, + }, + Data: UpsertStreamedTestData{ + Entities: []User{ + {Id: 3, Name: "Evan Davis", Age: 35, Email: "evan.davis@example.com"}, + {Id: 4, Name: "Fiona White", Age: 27, Email: "fiona.white@example.com"}, + }, + }, + }, + { + Name: "InsertAndUpdate", + Expected: []User{ + {Id: 1, Name: "Alice Johnson", Age: 25, Email: "alice.johnson@example.com"}, + {Id: 2, Name: "Bob Smith", Age: 30, Email: "bob.smith@example.com"}, + {Id: 3, Name: "Charlie Brown", Age: 22, Email: "charlie.brown@example.com"}, + {Id: 4, Name: "George King", Age: 29, Email: "george.king@example.com"}, + {Id: 5, Name: "Hannah Moore", Age: 31, Email: "hannah.moore@example.com"}, + {Id: 6, Name: "Fiona White", Age: 27, Email: "fiona.white@example.com"}, + }, + Data: UpsertStreamedTestData{ + Entities: []User{ + {Id: 5, Name: "Evan Davis", Age: 35, Email: "evan.davis@example.com"}, + {Id: 6, Name: "Fiona White", Age: 27, Email: "fiona.white@example.com"}, + {Id: 4, Name: "George King", Age: 29, Email: "george.king@example.com"}, + {Id: 5, Name: "Hannah Moore", Age: 31, Email: "hannah.moore@example.com"}, + }, + }, + }, + { + Name: "WithStatement", + Expected: []User{ + {Id: 1, Name: "Alice Johnson", Age: 25, Email: "alice.johnson@example.com"}, + {Id: 2, Name: "Bob Smith", Age: 30, Email: "bob.smith@example.com"}, + {Id: 3, Name: "Charlie Brown", Age: 22, Email: "charlie.brown@example.com"}, + {Id: 4, Name: "Diana Prince", Age: 28, Email: "diana.prince@example.com"}, + {Id: 5, Name: "Evan Davis", Age: 35, Email: "evan.davis@example.com"}, + {Id: 6, Name: "Fiona White", Age: 27, Email: "fiona.white@example.com"}, + }, + Data: UpsertStreamedTestData{ + Entities: []User{ + {Id: 5, Name: "Evan Davis", Age: 35, Email: "evan.davis@example.com"}, + {Id: 6, Name: "Fiona White", Age: 27, Email: "fiona.white@example.com"}, + }, + Statement: NewUpsertStatement(&User{}), + }, + }, + { + Name: "WithFalseStatement", + Error: testutils.ErrorContains("can't perform"), // TODO (jr): is it the right way? + Data: UpsertStreamedTestData{ + Entities: []User{ + {Id: 5, Name: "test5", Age: 50, Email: "test5@test.com"}, + }, + Statement: NewUpsertStatement(&User{}).Into("false_table"), + }, + }, + } + + for _, tst := range tests { + t.Run(tst.Name, tst.F(func(data UpsertStreamedTestData) ([]User, error) { + var ( + upsertError error + ctx, cancel = context.WithCancel(context.Background()) + entities = make(chan User) + logs = getTestLogging() + db = getTestDb(logs) + ) + + go func() { + if tst.Data.Statement != nil { + upsertError = UpsertStreamed(ctx, db, entities, WithUpsertStatement(tst.Data.Statement)) + } else { + upsertError = UpsertStreamed(ctx, db, entities) + } + }() + + initTestDb(db) + prefillTestDb(db) + + for _, entity := range tst.Data.Entities { + entities <- entity + } + + var actual []User + + time.Sleep(time.Second) + + if err := db.Select(&actual, "SELECT * FROM user"); err != nil { + t.Fatalf("cannot select from database: %v", err) + } + + cancel() + _ = db.Close() + + return actual, upsertError + })) + } +} + +// TODO (jr) +//func TestUpsertStreamedCallback(t *testing.T) { +// tests := []testutils.TestCase[any, UpsertStreamedTestData]{ +// { +// Name: "OneCallback", +// Data: UpsertStreamedTestData{ +// Callbacks: []OnSuccess[User]{ +// func(ctx context.Context, affectedRows []User) error { +// +// }, +// }, +// }, +// }, +// } +//} + +// TODO (jr) +// func TestUpsertStreamedEarlyDbClose(t *testing.T) { +// +// } diff --git a/go.mod b/go.mod index 24d88353..3f01c6b7 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( go.uber.org/zap v1.27.0 golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 golang.org/x/sync v0.14.0 + modernc.org/sqlite v1.34.2 ) require ( @@ -26,8 +27,19 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/ncruces/go-strftime v0.1.9 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect go.uber.org/multierr v1.10.0 // indirect - golang.org/x/sys v0.21.0 // indirect + golang.org/x/sys v0.22.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // indirect + modernc.org/libc v1.55.3 // indirect + modernc.org/mathutil v1.6.0 // indirect + modernc.org/memory v1.8.0 // indirect + modernc.org/strutil v1.2.0 // indirect + modernc.org/token v1.1.0 // indirect ) diff --git a/go.sum b/go.sum index 1919342c..3988c575 100644 --- a/go.sum +++ b/go.sum @@ -14,28 +14,40 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/go-sql-driver/mysql v1.9.2 h1:4cNKDYQ1I84SXslGddlsrMhc8k4LeDVj6Ad6WRjiHuU= github.com/go-sql-driver/mysql v1.9.2/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= github.com/goccy/go-yaml v1.17.1 h1:LI34wktB2xEE3ONG/2Ar54+/HJVBriAGJ55PHls4YuY= github.com/goccy/go-yaml v1.17.1/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= +github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo= +github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/jessevdk/go-flags v1.6.1 h1:Cvu5U8UGrLay1rZfv/zP7iLpSHGUZ/Ou68T0iX1bBK4= github.com/jessevdk/go-flags v1.6.1/go.mod h1:Mk8T1hIAWpOiJiHa9rJASDK2UGWji0EuPGBnNLMooyc= github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o= github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A= github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= +github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/redis/go-redis/v9 v9.8.0 h1:q3nRvjrlge/6UD7eTu/DSg2uYiU2mCL0G/uzBWqhicI= github.com/redis/go-redis/v9 v9.8.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/ssgreg/journald v1.0.0 h1:0YmTDPJXxcWDPba12qNMdO6TxvfkFSYpFIJ31CwmLcU= github.com/ssgreg/journald v1.0.0/go.mod h1:RUckwmTM8ghGWPslq2+ZBZzbb9/2KgjzYZ4JEP+oRt0= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= @@ -48,11 +60,42 @@ go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 h1:y5zboxd6LQAqYIhHnB48p0ByQ/GnQx2BE33L8BOHQkI= golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6/go.mod h1:U6Lno4MTRCDY+Ba7aCcauB9T60gsv5s4ralQzP72ZoQ= +golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU= +golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= -golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= -golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= +golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc= +golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +modernc.org/cc/v4 v4.21.4 h1:3Be/Rdo1fpr8GrQ7IVw9OHtplU4gWbb+wNgeoBMmGLQ= +modernc.org/cc/v4 v4.21.4/go.mod h1:HM7VJTZbUCR3rV8EYBi9wxnJ0ZBRiGE5OeGXNA0IsLQ= +modernc.org/ccgo/v4 v4.19.2 h1:lwQZgvboKD0jBwdaeVCTouxhxAyN6iawF3STraAal8Y= +modernc.org/ccgo/v4 v4.19.2/go.mod h1:ysS3mxiMV38XGRTTcgo0DQTeTmAO4oCmJl1nX9VFI3s= +modernc.org/fileutil v1.3.0 h1:gQ5SIzK3H9kdfai/5x41oQiKValumqNTDXMvKo62HvE= +modernc.org/fileutil v1.3.0/go.mod h1:XatxS8fZi3pS8/hKG2GH/ArUogfxjpEKs3Ku3aK4JyQ= +modernc.org/gc/v2 v2.4.1 h1:9cNzOqPyMJBvrUipmynX0ZohMhcxPtMccYgGOJdOiBw= +modernc.org/gc/v2 v2.4.1/go.mod h1:wzN5dK1AzVGoH6XOzc3YZ+ey/jPgYHLuVckd62P0GYU= +modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 h1:5D53IMaUuA5InSeMu9eJtlQXS2NxAhyWQvkKEgXZhHI= +modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6/go.mod h1:Qz0X07sNOR1jWYCrJMEnbW/X55x206Q7Vt4mz6/wHp4= +modernc.org/libc v1.55.3 h1:AzcW1mhlPNrRtjS5sS+eW2ISCgSOLLNyFzRh/V3Qj/U= +modernc.org/libc v1.55.3/go.mod h1:qFXepLhz+JjFThQ4kzwzOjA/y/artDeg+pcYnY+Q83w= +modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4= +modernc.org/mathutil v1.6.0/go.mod h1:Ui5Q9q1TR2gFm0AQRqQUaBWFLAhQpCwNcuhBOSedWPo= +modernc.org/memory v1.8.0 h1:IqGTL6eFMaDZZhEWwcREgeMXYwmW83LYW8cROZYkg+E= +modernc.org/memory v1.8.0/go.mod h1:XPZ936zp5OMKGWPqbD3JShgd/ZoQ7899TUuQqxY+peU= +modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4= +modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= +modernc.org/sortutil v1.2.0 h1:jQiD3PfS2REGJNzNCMMaLSp/wdMNieTbKX920Cqdgqc= +modernc.org/sortutil v1.2.0/go.mod h1:TKU2s7kJMf1AE84OoiGppNHJwvB753OYfNl2WRb++Ss= +modernc.org/sqlite v1.34.2 h1:J9n76TPsfYYkFkZ9Uy1QphILYifiVEwwOT7yP5b++2Y= +modernc.org/sqlite v1.34.2/go.mod h1:dnR723UrTtjKpoHCAMN0Q/gZ9MT4r+iRvIBb9umWFkU= +modernc.org/strutil v1.2.0 h1:agBi9dp1I+eOnxXeiZawM8F4LawKv4NzGWSaLfyeNZA= +modernc.org/strutil v1.2.0/go.mod h1:/mdcBmfOibveCTBxUl5B5l6W+TTH1FXPLHZE6bTosX0= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/retry/retry.go b/retry/retry.go index fc1648cf..6bd7fde8 100644 --- a/retry/retry.go +++ b/retry/retry.go @@ -3,6 +3,7 @@ package retry import ( "context" "database/sql/driver" + stderrors "errors" "github.com/go-sql-driver/mysql" "github.com/icinga/icinga-go-library/backoff" "github.com/lib/pq" @@ -17,6 +18,8 @@ import ( // DefaultTimeout is our opinionated default timeout for retrying database and Redis operations. const DefaultTimeout = 5 * time.Minute +var ErrNotRetryable = stderrors.New("error not retryable") + // RetryableFunc is a retryable function. type RetryableFunc func(context.Context) error @@ -137,6 +140,10 @@ func ResetTimeout(t *time.Timer, d time.Duration) { // i.e. temporary, timeout, DNS, connection refused and reset, host down and unreachable and // network down and unreachable errors. In addition, any database error is considered retryable. func Retryable(err error) bool { + if errors.Is(err, ErrNotRetryable) { + return false + } + var temporary interface { Temporary() bool }