From 5e6292cf6af6bd6471572bd981ba4077e3b940d9 Mon Sep 17 00:00:00 2001 From: Sanyam Singhal Date: Sat, 21 Dec 2024 16:27:24 +0000 Subject: [PATCH] Using ExecuteSqls() to setup schema/objects specific to each test instead of depending on a common global init sql script --- yb-voyager/src/srcdb/main_test.go | 2 +- yb-voyager/src/srcdb/postgres_test.go | 94 ++++++++++++++---- yb-voyager/src/srcdb/yugbaytedb_test.go | 95 +++++++++++++++---- yb-voyager/src/tgtdb/conn_pool_test.go | 8 +- yb-voyager/src/tgtdb/main_test.go | 7 +- yb-voyager/src/tgtdb/postgres_test.go | 58 ++++++++--- yb-voyager/src/tgtdb/yugabytedb_test.go | 57 +++++++++-- yb-voyager/test/containers/mysql_container.go | 9 +- .../test/containers/oracle_container.go | 4 +- .../test/containers/postgres_container.go | 43 ++++++--- .../test_schemas/postgresql_schema.sql | 44 +-------- .../test_schemas/yugabytedb_schema.sql | 42 -------- yb-voyager/test/containers/testcontainers.go | 2 +- .../test/containers/yugabytedb_container.go | 7 +- 14 files changed, 294 insertions(+), 178 deletions(-) diff --git a/yb-voyager/src/srcdb/main_test.go b/yb-voyager/src/srcdb/main_test.go index 8215c2c124..72d88e19e6 100644 --- a/yb-voyager/src/srcdb/main_test.go +++ b/yb-voyager/src/srcdb/main_test.go @@ -47,7 +47,6 @@ func TestMain(m *testing.M) { // setting source db type, version and defaults postgresContainer := testcontainers.NewTestContainer("postgresql", nil) - // TODO: handle error err := postgresContainer.Start(ctx) if err != nil { utils.ErrExit("Failed to start postgres container: %v", err) @@ -82,6 +81,7 @@ func TestMain(m *testing.M) { if err != nil { utils.ErrExit("%v", err) } + testOracleSource = &TestDB{ TestContainer: oracleContainer, Source: &Source{ diff --git a/yb-voyager/src/srcdb/postgres_test.go b/yb-voyager/src/srcdb/postgres_test.go index 6641b5d353..d25a19005a 100644 --- a/yb-voyager/src/srcdb/postgres_test.go +++ b/yb-voyager/src/srcdb/postgres_test.go @@ -26,43 +26,79 @@ import ( ) func TestPostgresGetAllTableNames(t *testing.T) { + testPostgresSource.TestContainer.ExecuteSqls( + `CREATE SCHEMA test_schema;`, + `CREATE TABLE test_schema.foo ( + id INT PRIMARY KEY, + name VARCHAR + );`, + `INSERT into test_schema.foo values (1, 'abc'), (2, 'xyz');`, + `CREATE TABLE test_schema.bar ( + id INT PRIMARY KEY, + name VARCHAR + );`, + `INSERT into test_schema.bar values (1, 'abc'), (2, 'xyz');`, + `CREATE TABLE test_schema.non_pk1( + id INT, + name VARCHAR(255) + );`) + defer testPostgresSource.TestContainer.ExecuteSqls(`DROP SCHEMA test_schema CASCADE;`) + sqlname.SourceDBType = "postgresql" + testPostgresSource.Source.Schema = "test_schema" // Test GetAllTableNames actualTables := testPostgresSource.DB().GetAllTableNames() expectedTables := []*sqlname.SourceName{ - sqlname.NewSourceName("public", "foo"), - sqlname.NewSourceName("public", "bar"), - sqlname.NewSourceName("public", "table1"), - sqlname.NewSourceName("public", "table2"), - sqlname.NewSourceName("public", "unique_table"), - sqlname.NewSourceName("public", "non_pk1"), - sqlname.NewSourceName("public", "non_pk2"), + sqlname.NewSourceName("test_schema", "foo"), + sqlname.NewSourceName("test_schema", "bar"), + sqlname.NewSourceName("test_schema", "non_pk1"), } assert.Equal(t, len(expectedTables), len(actualTables), "Expected number of tables to match") - testutils.AssertEqualSourceNameSlices(t, expectedTables, actualTables) } func TestPostgresGetTableToUniqueKeyColumnsMap(t *testing.T) { - objectName := sqlname.NewObjectName("postgresql", "public", "public", "unique_table") - - // Test GetTableToUniqueKeyColumnsMap - tableList := []sqlname.NameTuple{ - {CurrentName: objectName}, + testPostgresSource.TestContainer.ExecuteSqls( + `CREATE SCHEMA test_schema;`, + `CREATE TABLE test_schema.unique_table ( + id SERIAL PRIMARY KEY, + email VARCHAR(255) UNIQUE, + phone VARCHAR(20) UNIQUE, + address VARCHAR(255) UNIQUE + );`, + `INSERT INTO test_schema.unique_table (email, phone, address) VALUES + ('john@example.com', '1234567890', '123 Elm Street'), + ('jane@example.com', '0987654321', '456 Oak Avenue');`, + `CREATE TABLE test_schema.another_unique_table ( + user_id SERIAL PRIMARY KEY, + username VARCHAR(50) UNIQUE, + age INT + );`, + `CREATE UNIQUE INDEX idx_age ON test_schema.another_unique_table(age);`, + `INSERT INTO test_schema.another_unique_table (username, age) VALUES + ('user1', 30), + ('user2', 40);`) + defer testPostgresSource.TestContainer.ExecuteSqls(`DROP SCHEMA test_schema CASCADE;`) + + uniqueTablesList := []sqlname.NameTuple{ + {CurrentName: sqlname.NewObjectName("postgresql", "test_schema", "test_schema", "unique_table")}, + {CurrentName: sqlname.NewObjectName("postgresql", "test_schema", "test_schema", "another_unique_table")}, } - uniqueKeys, err := testPostgresSource.DB().GetTableToUniqueKeyColumnsMap(tableList) + + actualUniqKeys, err := testPostgresSource.DB().GetTableToUniqueKeyColumnsMap(uniqueTablesList) if err != nil { t.Fatalf("Error retrieving unique keys: %v", err) } - expectedKeys := map[string][]string{ - "unique_table": {"email", "phone", "address"}, + expectedUniqKeys := map[string][]string{ + "test_schema.unique_table": {"email", "phone", "address"}, + "test_schema.another_unique_table": {"username", "age"}, } // Compare the maps by iterating over each table and asserting the columns list - for table, expectedColumns := range expectedKeys { - actualColumns, exists := uniqueKeys[table] + for table, expectedColumns := range expectedUniqKeys { + actualColumns, exists := actualUniqKeys[table] if !exists { t.Errorf("Expected table %s not found in uniqueKeys", table) } @@ -72,9 +108,29 @@ func TestPostgresGetTableToUniqueKeyColumnsMap(t *testing.T) { } func TestPostgresGetNonPKTables(t *testing.T) { + testPostgresSource.TestContainer.ExecuteSqls( + `CREATE SCHEMA test_schema;`, + `CREATE TABLE test_schema.table1 ( + id SERIAL PRIMARY KEY, + name VARCHAR(100) + );`, + `CREATE TABLE test_schema.table2 ( + id SERIAL PRIMARY KEY, + email VARCHAR(100) + );`, + `CREATE TABLE test_schema.non_pk1( + id INT, + name VARCHAR(255) + );`, + `CREATE TABLE test_schema.non_pk2( + id INT, + name VARCHAR(255) + );`) + defer testPostgresSource.TestContainer.ExecuteSqls(`DROP SCHEMA test_schema CASCADE;`) + actualTables, err := testPostgresSource.DB().GetNonPKTables() assert.NilError(t, err, "Expected nil but non nil error: %v", err) - expectedTables := []string{`public."non_pk2"`, `public."non_pk1"`} // func returns table.Qualified.Quoted + expectedTables := []string{`test_schema."non_pk2"`, `test_schema."non_pk1"`} // func returns table.Qualified.Quoted testutils.AssertEqualStringSlices(t, expectedTables, actualTables) } diff --git a/yb-voyager/src/srcdb/yugbaytedb_test.go b/yb-voyager/src/srcdb/yugbaytedb_test.go index 6d414d2ee4..4a7cf8e6f9 100644 --- a/yb-voyager/src/srcdb/yugbaytedb_test.go +++ b/yb-voyager/src/srcdb/yugbaytedb_test.go @@ -26,42 +26,79 @@ import ( ) func TestYugabyteGetAllTableNames(t *testing.T) { - sqlname.SourceDBType = "yugabytedb" + testYugabyteDBSource.TestContainer.ExecuteSqls( + `CREATE SCHEMA test_schema;`, + `CREATE TABLE test_schema.foo ( + id INT PRIMARY KEY, + name VARCHAR + );`, + `INSERT into test_schema.foo values (1, 'abc'), (2, 'xyz');`, + `CREATE TABLE test_schema.bar ( + id INT PRIMARY KEY, + name VARCHAR + );`, + `INSERT into test_schema.bar values (1, 'abc'), (2, 'xyz');`, + `CREATE TABLE test_schema.non_pk1( + id INT, + name VARCHAR(255) + );`) + defer testYugabyteDBSource.TestContainer.ExecuteSqls(`DROP SCHEMA test_schema CASCADE;`) + + sqlname.SourceDBType = "postgresql" + testYugabyteDBSource.Source.Schema = "test_schema" // Test GetAllTableNames actualTables := testYugabyteDBSource.DB().GetAllTableNames() expectedTables := []*sqlname.SourceName{ - sqlname.NewSourceName("public", "foo"), - sqlname.NewSourceName("public", "bar"), - sqlname.NewSourceName("public", "table1"), - sqlname.NewSourceName("public", "table2"), - sqlname.NewSourceName("public", "unique_table"), - sqlname.NewSourceName("public", "non_pk1"), - sqlname.NewSourceName("public", "non_pk2"), + sqlname.NewSourceName("test_schema", "foo"), + sqlname.NewSourceName("test_schema", "bar"), + sqlname.NewSourceName("test_schema", "non_pk1"), } assert.Equal(t, len(expectedTables), len(actualTables), "Expected number of tables to match") - testutils.AssertEqualSourceNameSlices(t, expectedTables, actualTables) } func TestYugabyteGetTableToUniqueKeyColumnsMap(t *testing.T) { - objectName := sqlname.NewObjectName("yugabytedb", "public", "public", "unique_table") + testYugabyteDBSource.TestContainer.ExecuteSqls( + `CREATE SCHEMA test_schema;`, + `CREATE TABLE test_schema.unique_table ( + id SERIAL PRIMARY KEY, + email VARCHAR(255) UNIQUE, + phone VARCHAR(20) UNIQUE, + address VARCHAR(255) UNIQUE + );`, + `INSERT INTO test_schema.unique_table (email, phone, address) VALUES + ('john@example.com', '1234567890', '123 Elm Street'), + ('jane@example.com', '0987654321', '456 Oak Avenue');`, + `CREATE TABLE test_schema.another_unique_table ( + user_id SERIAL PRIMARY KEY, + username VARCHAR(50) UNIQUE, + age INT + );`, + `CREATE UNIQUE INDEX idx_age ON test_schema.another_unique_table(age);`, + `INSERT INTO test_schema.another_unique_table (username, age) VALUES + ('user1', 30), + ('user2', 40);`) + defer testYugabyteDBSource.TestContainer.ExecuteSqls(`DROP SCHEMA test_schema CASCADE;`) - // Test GetTableToUniqueKeyColumnsMap - tableList := []sqlname.NameTuple{ - {CurrentName: objectName}, + uniqueTablesList := []sqlname.NameTuple{ + {CurrentName: sqlname.NewObjectName("postgresql", "test_schema", "test_schema", "unique_table")}, + {CurrentName: sqlname.NewObjectName("postgresql", "test_schema", "test_schema", "another_unique_table")}, } - uniqueKeys, err := testYugabyteDBSource.DB().GetTableToUniqueKeyColumnsMap(tableList) + + actualUniqKeys, err := testYugabyteDBSource.DB().GetTableToUniqueKeyColumnsMap(uniqueTablesList) if err != nil { t.Fatalf("Error retrieving unique keys: %v", err) } - expectedKeys := map[string][]string{ - "unique_table": {"email", "phone", "address"}, + expectedUniqKeys := map[string][]string{ + "test_schema.unique_table": {"email", "phone", "address"}, + "test_schema.another_unique_table": {"username", "age"}, } + // Compare the maps by iterating over each table and asserting the columns list - for table, expectedColumns := range expectedKeys { - actualColumns, exists := uniqueKeys[table] + for table, expectedColumns := range expectedUniqKeys { + actualColumns, exists := actualUniqKeys[table] if !exists { t.Errorf("Expected table %s not found in uniqueKeys", table) } @@ -71,9 +108,29 @@ func TestYugabyteGetTableToUniqueKeyColumnsMap(t *testing.T) { } func TestYugabyteGetNonPKTables(t *testing.T) { + testYugabyteDBSource.TestContainer.ExecuteSqls( + `CREATE SCHEMA test_schema;`, + `CREATE TABLE test_schema.table1 ( + id SERIAL PRIMARY KEY, + name VARCHAR(100) + );`, + `CREATE TABLE test_schema.table2 ( + id SERIAL PRIMARY KEY, + email VARCHAR(100) + );`, + `CREATE TABLE test_schema.non_pk1( + id INT, + name VARCHAR(255) + );`, + `CREATE TABLE test_schema.non_pk2( + id INT, + name VARCHAR(255) + );`) + defer testYugabyteDBSource.TestContainer.ExecuteSqls(`DROP SCHEMA test_schema CASCADE;`) + actualTables, err := testYugabyteDBSource.DB().GetNonPKTables() assert.NilError(t, err, "Expected nil but non nil error: %v", err) - expectedTables := []string{`public."non_pk2"`, `public."non_pk1"`} // func returns table.Qualified.Quoted + expectedTables := []string{`test_schema."non_pk2"`, `test_schema."non_pk1"`} // func returns table.Qualified.Quoted testutils.AssertEqualStringSlices(t, expectedTables, actualTables) } diff --git a/yb-voyager/src/tgtdb/conn_pool_test.go b/yb-voyager/src/tgtdb/conn_pool_test.go index 374bb88d91..7d4ce43805 100644 --- a/yb-voyager/src/tgtdb/conn_pool_test.go +++ b/yb-voyager/src/tgtdb/conn_pool_test.go @@ -34,7 +34,7 @@ func TestBasic(t *testing.T) { connParams := &ConnectionParams{ NumConnections: size, NumMaxConnections: size, - ConnUriList: []string{testYugabyteDBTarget.Container.GetConnectionString()}, + ConnUriList: []string{testYugabyteDBTarget.GetConnectionString()}, SessionInitScript: []string{}, } pool := NewConnectionPool(connParams) @@ -69,7 +69,7 @@ func TestIncreaseConnectionsUptoMax(t *testing.T) { connParams := &ConnectionParams{ NumConnections: size, NumMaxConnections: maxSize, - ConnUriList: []string{testYugabyteDBTarget.Container.GetConnectionString()}, + ConnUriList: []string{testYugabyteDBTarget.GetConnectionString()}, SessionInitScript: []string{}, } pool := NewConnectionPool(connParams) @@ -112,7 +112,7 @@ func TestDecreaseConnectionsUptoMin(t *testing.T) { connParams := &ConnectionParams{ NumConnections: size, NumMaxConnections: maxSize, - ConnUriList: []string{testYugabyteDBTarget.Container.GetConnectionString()}, + ConnUriList: []string{testYugabyteDBTarget.GetConnectionString()}, SessionInitScript: []string{}, } pool := NewConnectionPool(connParams) @@ -155,7 +155,7 @@ func TestUpdateConnectionsRandom(t *testing.T) { connParams := &ConnectionParams{ NumConnections: size, NumMaxConnections: maxSize, - ConnUriList: []string{testYugabyteDBTarget.Container.GetConnectionString()}, + ConnUriList: []string{testYugabyteDBTarget.GetConnectionString()}, SessionInitScript: []string{}, } pool := NewConnectionPool(connParams) diff --git a/yb-voyager/src/tgtdb/main_test.go b/yb-voyager/src/tgtdb/main_test.go index 0d6be93816..9daf45d1cd 100644 --- a/yb-voyager/src/tgtdb/main_test.go +++ b/yb-voyager/src/tgtdb/main_test.go @@ -28,7 +28,7 @@ import ( ) type TestDB struct { - Container testcontainers.TestContainer + testcontainers.TestContainer TargetDB } @@ -42,7 +42,6 @@ func TestMain(m *testing.M) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - // setting source db type, version and defaults postgresContainer := testcontainers.NewTestContainer("postgresql", nil) err := postgresContainer.Start(ctx) if err != nil { @@ -53,7 +52,7 @@ func TestMain(m *testing.M) { utils.ErrExit("%v", err) } testPostgresTarget = &TestDB{ - Container: postgresContainer, + TestContainer: postgresContainer, TargetDB: NewTargetDB(&TargetConf{ TargetDBType: "postgresql", DBVersion: postgresContainer.GetConfig().DBVersion, @@ -109,7 +108,7 @@ func TestMain(m *testing.M) { utils.ErrExit("%v", err) } testYugabyteDBTarget = &TestDB{ - Container: yugabytedbContainer, + TestContainer: yugabytedbContainer, TargetDB: NewTargetDB(&TargetConf{ TargetDBType: "yugabytedb", DBVersion: yugabytedbContainer.GetConfig().DBVersion, diff --git a/yb-voyager/src/tgtdb/postgres_test.go b/yb-voyager/src/tgtdb/postgres_test.go index fd093e77a4..e0e9bd60c5 100644 --- a/yb-voyager/src/tgtdb/postgres_test.go +++ b/yb-voyager/src/tgtdb/postgres_test.go @@ -28,7 +28,7 @@ import ( ) func TestCreateVoyagerSchemaPG(t *testing.T) { - db, err := sql.Open("pgx", testPostgresTarget.Container.GetConnectionString()) + db, err := sql.Open("pgx", testPostgresTarget.GetConnectionString()) assert.NoError(t, err) defer db.Close() @@ -88,22 +88,58 @@ func TestCreateVoyagerSchemaPG(t *testing.T) { } func TestPostgresGetNonEmptyTables(t *testing.T) { + testPostgresTarget.ExecuteSqls( + `CREATE SCHEMA test_schema`, + `CREATE TABLE test_schema.foo ( + id INT PRIMARY KEY, + name VARCHAR + );`, + `INSERT into test_schema.foo values (1, 'abc'), (2, 'xyz');`, + `CREATE TABLE test_schema.bar ( + id INT PRIMARY KEY, + name VARCHAR + );`, + `INSERT into test_schema.bar values (1, 'abc'), (2, 'xyz');`, + `CREATE TABLE test_schema.unique_table ( + id SERIAL PRIMARY KEY, + email VARCHAR(100), + phone VARCHAR(100), + address VARCHAR(255), + UNIQUE (email, phone) -- Unique constraint on combination of columns + );`, + `CREATE TABLE test_schema.table1 ( + id SERIAL PRIMARY KEY, + name VARCHAR(100) + );`, + `CREATE TABLE test_schema.table2 ( + id SERIAL PRIMARY KEY, + email VARCHAR(100) + );`, + `CREATE TABLE test_schema.non_pk1( + id INT, + name VARCHAR(255) + );`, + `CREATE TABLE test_schema.non_pk2( + id INT, + name VARCHAR(255) + );`) + defer testPostgresTarget.ExecuteSqls(`DROP SCHEMA test_schema CASCADE;`) + tables := []sqlname.NameTuple{ - {CurrentName: sqlname.NewObjectName(POSTGRESQL, "public", "public", "foo")}, - {CurrentName: sqlname.NewObjectName(POSTGRESQL, "public", "public", "bar")}, - {CurrentName: sqlname.NewObjectName(POSTGRESQL, "public", "public", "unique_table")}, - {CurrentName: sqlname.NewObjectName(POSTGRESQL, "public", "public", "table1")}, - {CurrentName: sqlname.NewObjectName(POSTGRESQL, "public", "public", "table2")}, - {CurrentName: sqlname.NewObjectName(POSTGRESQL, "public", "public", "non_pk1")}, - {CurrentName: sqlname.NewObjectName(POSTGRESQL, "public", "public", "non_pk2")}, + {CurrentName: sqlname.NewObjectName(POSTGRESQL, "test_schema", "test_schema", "foo")}, + {CurrentName: sqlname.NewObjectName(POSTGRESQL, "test_schema", "test_schema", "bar")}, + {CurrentName: sqlname.NewObjectName(POSTGRESQL, "test_schema", "test_schema", "unique_table")}, + {CurrentName: sqlname.NewObjectName(POSTGRESQL, "test_schema", "test_schema", "table1")}, + {CurrentName: sqlname.NewObjectName(POSTGRESQL, "test_schema", "test_schema", "table2")}, + {CurrentName: sqlname.NewObjectName(POSTGRESQL, "test_schema", "test_schema", "non_pk1")}, + {CurrentName: sqlname.NewObjectName(POSTGRESQL, "test_schema", "test_schema", "non_pk2")}, } expectedTables := []sqlname.NameTuple{ - {CurrentName: sqlname.NewObjectName(POSTGRESQL, "public", "public", "foo")}, - {CurrentName: sqlname.NewObjectName(POSTGRESQL, "public", "public", "bar")}, + {CurrentName: sqlname.NewObjectName(POSTGRESQL, "test_schema", "test_schema", "foo")}, + {CurrentName: sqlname.NewObjectName(POSTGRESQL, "test_schema", "test_schema", "bar")}, } actualTables := testPostgresTarget.GetNonEmptyTables(tables) - fmt.Printf("non empty tables: %+v\n", actualTables) testutils.AssertEqualNameTuplesSlice(t, expectedTables, actualTables) } diff --git a/yb-voyager/src/tgtdb/yugabytedb_test.go b/yb-voyager/src/tgtdb/yugabytedb_test.go index 0f02329f33..c9ac66ce2f 100644 --- a/yb-voyager/src/tgtdb/yugabytedb_test.go +++ b/yb-voyager/src/tgtdb/yugabytedb_test.go @@ -28,7 +28,7 @@ import ( ) func TestCreateVoyagerSchemaYB(t *testing.T) { - db, err := sql.Open("pgx", testYugabyteDBTarget.Container.GetConnectionString()) + db, err := sql.Open("pgx", testYugabyteDBTarget.GetConnectionString()) assert.NoError(t, err) defer db.Close() @@ -88,19 +88,56 @@ func TestCreateVoyagerSchemaYB(t *testing.T) { } func TestYugabyteGetNonEmptyTables(t *testing.T) { + testYugabyteDBTarget.ExecuteSqls( + `CREATE SCHEMA test_schema`, + `CREATE TABLE test_schema.foo ( + id INT PRIMARY KEY, + name VARCHAR + );`, + `INSERT into test_schema.foo values (1, 'abc'), (2, 'xyz');`, + `CREATE TABLE test_schema.bar ( + id INT PRIMARY KEY, + name VARCHAR + );`, + `INSERT into test_schema.bar values (1, 'abc'), (2, 'xyz');`, + `CREATE TABLE test_schema.unique_table ( + id SERIAL PRIMARY KEY, + email VARCHAR(100), + phone VARCHAR(100), + address VARCHAR(255), + UNIQUE (email, phone) + );`, + `CREATE TABLE test_schema.table1 ( + id SERIAL PRIMARY KEY, + name VARCHAR(100) + );`, + `CREATE TABLE test_schema.table2 ( + id SERIAL PRIMARY KEY, + email VARCHAR(100) + );`, + `CREATE TABLE test_schema.non_pk1( + id INT, + name VARCHAR(255) + );`, + `CREATE TABLE test_schema.non_pk2( + id INT, + name VARCHAR(255) + );`) + defer testYugabyteDBTarget.ExecuteSqls(`DROP SCHEMA test_schema CASCADE;`) + tables := []sqlname.NameTuple{ - {CurrentName: sqlname.NewObjectName(YUGABYTEDB, "public", "public", "foo")}, - {CurrentName: sqlname.NewObjectName(YUGABYTEDB, "public", "public", "bar")}, - {CurrentName: sqlname.NewObjectName(YUGABYTEDB, "public", "public", "unique_table")}, - {CurrentName: sqlname.NewObjectName(YUGABYTEDB, "public", "public", "table1")}, - {CurrentName: sqlname.NewObjectName(YUGABYTEDB, "public", "public", "table2")}, - {CurrentName: sqlname.NewObjectName(YUGABYTEDB, "public", "public", "non_pk1")}, - {CurrentName: sqlname.NewObjectName(YUGABYTEDB, "public", "public", "non_pk2")}, + {CurrentName: sqlname.NewObjectName(YUGABYTEDB, "test_schema", "test_schema", "foo")}, + {CurrentName: sqlname.NewObjectName(YUGABYTEDB, "test_schema", "test_schema", "bar")}, + {CurrentName: sqlname.NewObjectName(YUGABYTEDB, "test_schema", "test_schema", "unique_table")}, + {CurrentName: sqlname.NewObjectName(YUGABYTEDB, "test_schema", "test_schema", "table1")}, + {CurrentName: sqlname.NewObjectName(YUGABYTEDB, "test_schema", "test_schema", "table2")}, + {CurrentName: sqlname.NewObjectName(YUGABYTEDB, "test_schema", "test_schema", "non_pk1")}, + {CurrentName: sqlname.NewObjectName(YUGABYTEDB, "test_schema", "test_schema", "non_pk2")}, } expectedTables := []sqlname.NameTuple{ - {CurrentName: sqlname.NewObjectName(YUGABYTEDB, "public", "public", "foo")}, - {CurrentName: sqlname.NewObjectName(YUGABYTEDB, "public", "public", "bar")}, + {CurrentName: sqlname.NewObjectName(YUGABYTEDB, "test_schema", "test_schema", "foo")}, + {CurrentName: sqlname.NewObjectName(YUGABYTEDB, "test_schema", "test_schema", "bar")}, } actualTables := testYugabyteDBTarget.GetNonEmptyTables(tables) diff --git a/yb-voyager/test/containers/mysql_container.go b/yb-voyager/test/containers/mysql_container.go index 810347b969..5c6aa5114c 100644 --- a/yb-voyager/test/containers/mysql_container.go +++ b/yb-voyager/test/containers/mysql_container.go @@ -123,7 +123,6 @@ func (ms *MysqlContainer) GetConfig() ContainerConfig { return ms.ContainerConfig } -// GetConnectionString constructs and returns the MySQL DSN func (ms *MysqlContainer) GetConnectionString() string { host, port, err := ms.GetHostPort() if err != nil { @@ -135,17 +134,15 @@ func (ms *MysqlContainer) GetConnectionString() string { ms.User, ms.Password, host, port, ms.DBName) } -// ExecuteSqls executes a list of SQL statements using the persistent DB connection -func (ms *MysqlContainer) ExecuteSqls(sqls []string) error { +func (ms *MysqlContainer) ExecuteSqls(sqls ...string) { if ms.db == nil { - return fmt.Errorf("db connection not initialized for mysql container") + utils.ErrExit("db connection not initialized for mysql container") } for _, sqlStmt := range sqls { _, err := ms.db.Exec(sqlStmt) if err != nil { - return fmt.Errorf("failed to execute sql '%s': %w", sqlStmt, err) + utils.ErrExit("failed to execute sql '%s': %w", sqlStmt, err) } } - return nil } diff --git a/yb-voyager/test/containers/oracle_container.go b/yb-voyager/test/containers/oracle_container.go index fab39893ae..19dd50b83f 100644 --- a/yb-voyager/test/containers/oracle_container.go +++ b/yb-voyager/test/containers/oracle_container.go @@ -101,6 +101,6 @@ func (ora *OracleContainer) GetConnectionString() string { panic("GetConnectionString() not implemented yet for oracle") } -func (ora *OracleContainer) ExecuteSqls(sqls []string) error { - return nil +func (ora *OracleContainer) ExecuteSqls(sqls ...string) { + } diff --git a/yb-voyager/test/containers/postgres_container.go b/yb-voyager/test/containers/postgres_container.go index 74006354e7..afe318e90e 100644 --- a/yb-voyager/test/containers/postgres_container.go +++ b/yb-voyager/test/containers/postgres_container.go @@ -2,13 +2,13 @@ package testcontainers import ( "context" + "database/sql" "fmt" "io" "os" "time" "github.com/docker/go-connections/nat" - "github.com/jackc/pgx/v5" log "github.com/sirupsen/logrus" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/wait" @@ -18,6 +18,7 @@ import ( type PostgresContainer struct { ContainerConfig container testcontainers.Container + db *sql.DB } func (pg *PostgresContainer) Start(ctx context.Context) (err error) { @@ -78,9 +79,24 @@ func (pg *PostgresContainer) Start(ctx context.Context) (err error) { fmt.Println("=== End of Logs ===") } } + return err } - return err + dsn := pg.GetConnectionString() + db, err := sql.Open("pgx", dsn) + if err != nil { + return fmt.Errorf("failed to open postgres connection: %w", err) + } + + if err := db.Ping(); err != nil { + db.Close() + pg.container.Terminate(ctx) + return fmt.Errorf("failed to ping postgres after connection: %w", err) + } + + // Store the DB connection for reuse + pg.db = db + return nil } func (pg *PostgresContainer) Terminate(ctx context.Context) { @@ -88,6 +104,13 @@ func (pg *PostgresContainer) Terminate(ctx context.Context) { return } + // Close the DB connection if it exists + if pg.db != nil { + if err := pg.db.Close(); err != nil { + log.Errorf("failed to close postgres db connection: %v", err) + } + } + err := pg.container.Terminate(ctx) if err != nil { log.Errorf("failed to terminate postgres container: %v", err) @@ -127,19 +150,15 @@ func (pg *PostgresContainer) GetConnectionString() string { return fmt.Sprintf("postgresql://%s:%s@%s:%d/%s", config.User, config.Password, host, port, config.DBName) } -func (pg *PostgresContainer) ExecuteSqls(sqls []string) error { - connStr := pg.GetConnectionString() - conn, err := pgx.Connect(context.Background(), connStr) - if err != nil { - return fmt.Errorf("failed to connect postgres for executing sqls: %w", err) +func (pg *PostgresContainer) ExecuteSqls(sqls ...string) { + if pg.db == nil { + utils.ErrExit("db connection not initialized for postgres container") } - defer conn.Close(context.Background()) - for _, sql := range sqls { - _, err := conn.Exec(context.Background(), sql) + for _, sqlStmt := range sqls { + _, err := pg.db.Exec(sqlStmt) if err != nil { - return fmt.Errorf("failed to execute sql '%s': %w", sql, err) + utils.ErrExit("failed to execute sql '%s': %w", sqlStmt, err) } } - return nil } diff --git a/yb-voyager/test/containers/test_schemas/postgresql_schema.sql b/yb-voyager/test/containers/test_schemas/postgresql_schema.sql index 4b81783600..36bda657a5 100644 --- a/yb-voyager/test/containers/test_schemas/postgresql_schema.sql +++ b/yb-voyager/test/containers/test_schemas/postgresql_schema.sql @@ -1,43 +1 @@ --- TODO: create user as per User creation steps in docs and use that in tests - -CREATE TABLE public.foo ( - id INT PRIMARY KEY, - name VARCHAR -); -INSERT into public.foo values (1, 'abc'), (2, 'xyz'); - -CREATE TABLE public.bar ( - id INT PRIMARY KEY, - name VARCHAR -); -INSERT into public.bar values (1, 'abc'), (2, 'xyz'); - -CREATE TABLE public.unique_table ( - id SERIAL PRIMARY KEY, - email VARCHAR(100), - phone VARCHAR(100), - address VARCHAR(255), - UNIQUE (email, phone) -- Unique constraint on combination of columns -); - -CREATE UNIQUE INDEX unique_address_idx ON public.unique_table (address); -- Unique Index - -CREATE TABLE public.table1 ( - id SERIAL PRIMARY KEY, - name VARCHAR(100) -); - -CREATE TABLE public.table2 ( - id SERIAL PRIMARY KEY, - email VARCHAR(100) -); - -CREATE TABLE public.non_pk1( - id INT, - name VARCHAR(255) -); - -CREATE TABLE public.non_pk2( - id INT, - name VARCHAR(255) -); \ No newline at end of file +-- TODO: create source migration user as per User creation steps in docs and use that in tests diff --git a/yb-voyager/test/containers/test_schemas/yugabytedb_schema.sql b/yb-voyager/test/containers/test_schemas/yugabytedb_schema.sql index a260d5a4d4..c36ddc5b93 100644 --- a/yb-voyager/test/containers/test_schemas/yugabytedb_schema.sql +++ b/yb-voyager/test/containers/test_schemas/yugabytedb_schema.sql @@ -1,43 +1 @@ -- TODO: create user as per User creation steps in docs and use that in tests - -CREATE TABLE public.foo ( - id INT PRIMARY KEY, - name VARCHAR -); -INSERT into public.foo values (1, 'abc'), (2, 'xyz'); - -CREATE TABLE public.bar ( - id INT PRIMARY KEY, - name VARCHAR -); -INSERT into public.bar values (1, 'abc'), (2, 'xyz'); - -CREATE TABLE public.unique_table ( - id SERIAL PRIMARY KEY, - email VARCHAR(100), - phone VARCHAR(100), - address VARCHAR(255), - UNIQUE (email, phone) -- Unique constraint on combination of columns -); - -CREATE UNIQUE INDEX unique_address_idx ON public.unique_table (address); - -CREATE TABLE public.table1 ( - id SERIAL PRIMARY KEY, - name VARCHAR(100) -); - -CREATE TABLE public.table2 ( - id SERIAL PRIMARY KEY, - email VARCHAR(100) -); - -CREATE TABLE public.non_pk1( - id INT, - name VARCHAR(255) -); - -CREATE TABLE public.non_pk2( - id INT, - name VARCHAR(255) -); \ No newline at end of file diff --git a/yb-voyager/test/containers/testcontainers.go b/yb-voyager/test/containers/testcontainers.go index 318e07e3ba..64439d31a4 100644 --- a/yb-voyager/test/containers/testcontainers.go +++ b/yb-voyager/test/containers/testcontainers.go @@ -28,7 +28,7 @@ type TestContainer interface { // Add Capability to run multiple versions of a dbtype parallely */ - ExecuteSqls(sqls []string) error + ExecuteSqls(sqls ...string) } type ContainerConfig struct { diff --git a/yb-voyager/test/containers/yugabytedb_container.go b/yb-voyager/test/containers/yugabytedb_container.go index 9ca14cde67..08921b56f4 100644 --- a/yb-voyager/test/containers/yugabytedb_container.go +++ b/yb-voyager/test/containers/yugabytedb_container.go @@ -111,19 +111,18 @@ func (yb *YugabyteDBContainer) GetConnectionString() string { return fmt.Sprintf("postgresql://%s:%s@%s:%d/%s", config.User, config.Password, host, port, config.DBName) } -func (yb *YugabyteDBContainer) ExecuteSqls(sqls []string) error { +func (yb *YugabyteDBContainer) ExecuteSqls(sqls ...string) { connStr := yb.GetConnectionString() conn, err := pgx.Connect(context.Background(), connStr) if err != nil { - return fmt.Errorf("failed to connect postgres for executing sqls: %w", err) + utils.ErrExit("failed to connect postgres for executing sqls: %w", err) } defer conn.Close(context.Background()) for _, sql := range sqls { _, err := conn.Exec(context.Background(), sql) if err != nil { - return fmt.Errorf("failed to execute sql '%s': %w", sql, err) + utils.ErrExit("failed to execute sql '%s': %w", sql, err) } } - return nil }