Skip to content

Commit

Permalink
Implement ExecuteSqls() for each test container type
Browse files Browse the repository at this point in the history
  • Loading branch information
sanyamsinghal committed Dec 20, 2024
1 parent d75762d commit 788ea74
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ on:

jobs:

build:
build-and-test:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
Expand Down
54 changes: 52 additions & 2 deletions yb-voyager/test/containers/mysql_container.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package testcontainers

import (
"context"
"database/sql"
"fmt"
"os"
"time"
Expand All @@ -16,6 +17,7 @@ import (
type MysqlContainer struct {
ContainerConfig
container testcontainers.Container
db *sql.DB
}

func (ms *MysqlContainer) Start(ctx context.Context) (err error) {
Expand Down Expand Up @@ -59,14 +61,39 @@ func (ms *MysqlContainer) Start(ctx context.Context) (err error) {
ContainerRequest: req,
Started: true,
})
return err
if err != nil {
return err
}

dsn := ms.GetConnectionString()
db, err := sql.Open("mysql", dsn)
if err != nil {
return fmt.Errorf("failed to open mysql connection: %w", err)
}

if err = db.Ping(); err != nil {
db.Close()
return fmt.Errorf("failed to ping mysql after connection: %w", err)
}

// Store the DB connection for reuse
ms.db = db

return nil
}

func (ms *MysqlContainer) Terminate(ctx context.Context) {
if ms == nil {
return
}

// Close the DB connection if it exists
if ms.db != nil {
if err := ms.db.Close(); err != nil {
log.Errorf("failed to close mysql db connection: %v", err)
}
}

err := ms.container.Terminate(ctx)
if err != nil {
log.Errorf("failed to terminate mysql container: %v", err)
Expand Down Expand Up @@ -96,6 +123,29 @@ func (ms *MysqlContainer) GetConfig() ContainerConfig {
return ms.ContainerConfig
}

// GetConnectionString constructs and returns the MySQL DSN
func (ms *MysqlContainer) GetConnectionString() string {
panic("GetConnectionString() not implemented yet for mysql")
host, port, err := ms.GetHostPort()
if err != nil {
utils.ErrExit("failed to get host port for mysql connection string: %v", err)
}

// DSN format: user:password@tcp(host:port)/dbname
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s",
ms.User, ms.Password, host, port, ms.DBName)
}

// ExecuteSqls executes a list of SQL statements using the persistent DB connection
func (ms *MysqlContainer) ExecuteSqls(sqls []string) error {
if ms.db == nil {
return fmt.Errorf("db connection not initialized for mysql container")
}

for _, sqlStmt := range sqls {
_, err := ms.db.Exec(sqlStmt)
if err != nil {
return fmt.Errorf("failed to execute sql '%s': %w", sqlStmt, err)
}
}
return nil
}
4 changes: 4 additions & 0 deletions yb-voyager/test/containers/oracle_container.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,7 @@ func (ora *OracleContainer) GetConfig() ContainerConfig {
func (ora *OracleContainer) GetConnectionString() string {
panic("GetConnectionString() not implemented yet for oracle")
}

func (ora *OracleContainer) ExecuteSqls(sqls []string) error {
return nil
}
18 changes: 18 additions & 0 deletions yb-voyager/test/containers/postgres_container.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"time"

"github.com/docker/go-connections/nat"
"github.com/jackc/pgx/v5"
log "github.com/sirupsen/logrus"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/wait"
Expand Down Expand Up @@ -125,3 +126,20 @@ func (pg *PostgresContainer) GetConnectionString() string {

return fmt.Sprintf("postgresql://%s:%s@%s:%d/%s", config.User, config.Password, host, port, config.DBName)
}

func (pg *PostgresContainer) ExecuteSqls(sqls []string) error {
connStr := pg.GetConnectionString()
conn, err := pgx.Connect(context.Background(), connStr)
if err != nil {
return fmt.Errorf("failed to connect postgres for executing sqls: %w", err)
}
defer conn.Close(context.Background())

for _, sql := range sqls {
_, err := conn.Exec(context.Background(), sql)
if err != nil {
return fmt.Errorf("failed to execute sql '%s': %w", sql, err)
}
}
return nil
}
1 change: 1 addition & 0 deletions yb-voyager/test/containers/testcontainers.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type TestContainer interface {
// Add Capability to run multiple versions of a dbtype parallely
*/
ExecuteSqls(sqls []string) error
}

type ContainerConfig struct {
Expand Down
18 changes: 18 additions & 0 deletions yb-voyager/test/containers/yugabytedb_container.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/docker/go-connections/nat"
"github.com/jackc/pgx/v5"
log "github.com/sirupsen/logrus"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/wait"
Expand Down Expand Up @@ -109,3 +110,20 @@ func (yb *YugabyteDBContainer) GetConnectionString() string {

return fmt.Sprintf("postgresql://%s:%s@%s:%d/%s", config.User, config.Password, host, port, config.DBName)
}

func (yb *YugabyteDBContainer) ExecuteSqls(sqls []string) error {
connStr := yb.GetConnectionString()
conn, err := pgx.Connect(context.Background(), connStr)
if err != nil {
return fmt.Errorf("failed to connect postgres for executing sqls: %w", err)
}
defer conn.Close(context.Background())

for _, sql := range sqls {
_, err := conn.Exec(context.Background(), sql)
if err != nil {
return fmt.Errorf("failed to execute sql '%s': %w", sql, err)
}
}
return nil
}

0 comments on commit 788ea74

Please sign in to comment.