diff --git a/integration/transaction_test.go b/integration/transaction_test.go index 1f4367c..371f013 100644 --- a/integration/transaction_test.go +++ b/integration/transaction_test.go @@ -53,19 +53,19 @@ func TestExecuteTransactionError(t *testing.T) { return err } - // sleeping for timeout - time.Sleep(time.Hour) + // wait until context expires + <-ctx.Done() // Patching the model tenantID := "" _, err = r.Patch(ctx, &model.System{ - ExternalID: expSys1.ExternalID, - TenantID: &tenantID, + ID: expSys1.ID, + TenantID: &tenantID, }) return err }) // then - assert.Equal(t, context.DeadlineExceeded, err) + assert.ErrorIs(t, err, context.DeadlineExceeded) }) t.Run("should able to do transaction for the same row after a error in first transaction", func(t *testing.T) { diff --git a/internal/repository/sql/postgres.go b/internal/repository/sql/postgres.go index 3753342..3da78ea 100644 --- a/internal/repository/sql/postgres.go +++ b/internal/repository/sql/postgres.go @@ -22,7 +22,7 @@ func StartDB(ctx context.Context, dbConf config.DB) (*gorm.DB, error) { return nil, err } - dbCon.WithContext(ctx) + dbCon = dbCon.WithContext(ctx) slog.Info("DB connection done") if err = Migrate(dbCon); err != nil { diff --git a/internal/repository/sql/resource_repository.go b/internal/repository/sql/resource_repository.go index 2276c0f..bcc706f 100644 --- a/internal/repository/sql/resource_repository.go +++ b/internal/repository/sql/resource_repository.go @@ -130,25 +130,11 @@ func (r ResourceRepository) PatchAll(ctx context.Context, resource repository.Re return db.RowsAffected, nil } -// Transaction will give transaction locking on particular rows. -// txFunc is a type TransactionFunc where we can define the transactional logic. -// if txFunc return no error then transaction is committed, -// else if txFunc return error then transaction is rolled back. -// Note: please dont use Goroutines inside the txFunc as this might lead to panic. +// Transaction executes txFunc inside a GORM transaction with SELECT FOR UPDATE locking. +// Commits on nil return, rolls back on error. func (r ResourceRepository) Transaction(ctx context.Context, txFunc repository.TransactionFunc) error { return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { - errorChan := make(chan error) - - go func() { - errorChan <- txFunc(ctx, NewRepository(tx.Clauses(clause.Locking{Strength: "UPDATE"}))) - }() - - select { - case <-ctx.Done(): - return ctx.Err() - case err := <-errorChan: - return err - } + return txFunc(ctx, NewRepository(tx.Clauses(clause.Locking{Strength: "UPDATE"}))) }) }