Skip to content
This repository was archived by the owner on Oct 9, 2023. It is now read-only.
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions database/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ import (
)

const (
database = "database"
postgres = "postgres"
database = "database"
postgresText = "postgres"
)

//go:generate pflags DbConfig --default-var=defaultConfig
Expand All @@ -19,9 +19,9 @@ var defaultConfig = &DbConfig{
ConnMaxLifeTime: config.Duration{Duration: time.Hour},
Postgres: PostgresConfig{
Port: 5432,
User: postgres,
Host: postgres,
DbName: postgres,
User: postgresText,
Host: postgresText,
DbName: postgresText,
ExtraOptions: "sslmode=disable",
},
}
Expand Down Expand Up @@ -53,6 +53,7 @@ type DbConfig struct {
ConnMaxLifeTime config.Duration `json:"connMaxLifeTime" pflag:",sets the maximum amount of time a connection may be reused"`
Postgres PostgresConfig `json:"postgres,omitempty"`
SQLite SQLiteConfig `json:"sqlite,omitempty"`
Mysql MysqlConfig `json:"mysql,omitempty"`
}

// SQLiteConfig can be used to configure
Expand All @@ -73,8 +74,18 @@ type PostgresConfig struct {
Debug bool `json:"debug" pflag:" Whether or not to start the database connection with debug mode enabled."`
}

type MysqlConfig struct {
Host string `json:"host" pflag:",The host name of the database server"`
Port int `json:"port" pflag:",The port name of the database server"`
DbName string `json:"dbname" pflag:",The database name"`
User string `json:"username" pflag:",The database user who is connecting to the server."`
// Either Password or PasswordPath must be set.
Password string `json:"password" pflag:",The database password."`
}

var emptySQLiteCfg = SQLiteConfig{}
var emptyPostgresConfig = PostgresConfig{}
var emptyMysqlConfig = MysqlConfig{}

func (s SQLiteConfig) IsEmpty() bool {
return s == emptySQLiteCfg
Expand All @@ -84,6 +95,10 @@ func (s PostgresConfig) IsEmpty() bool {
return s == emptyPostgresConfig
}

func (s MysqlConfig) IsEmpty() bool {
return s == emptyMysqlConfig
}

func GetConfig() *DbConfig {
databaseConfig := configSection.GetConfig().(*DbConfig)
if len(databaseConfig.DeprecatedHost) > 0 {
Expand Down
80 changes: 80 additions & 0 deletions database/mysql.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package database

import (
"context"
"errors"
"fmt"
"github.com/flyteorg/flytestdlib/logger"
mysql_driver "github.com/go-sql-driver/mysql"
"gorm.io/driver/mysql"
"gorm.io/gorm"
)

const dbNotExists uint16 = 1049

// Produces the DSN (data source name) for mysql connections
// Example: "user:pass@tcp(127.0.0.1:3306)/dbname?charset=utf8mb4&parseTime=True&loc=Local"
func getMysqlDsn(ctx context.Context, mysqlConfig MysqlConfig) string {
// Add reading from a file in the future
sqlConfig := mysql_driver.Config{
User: mysqlConfig.User,
Passwd: mysqlConfig.Password,
Net: "tcp",
Addr: fmt.Sprintf("%s:%d", mysqlConfig.Host, mysqlConfig.Port),
DBName: mysqlConfig.DbName,
Collation: "",
AllowCleartextPasswords: true,
MultiStatements: true,
ParseTime: true,
}
return sqlConfig.FormatDSN()
}

func CreateMysqlDbIfNotExists(ctx context.Context, gormConfig *gorm.Config, mysqlConfig MysqlConfig) (*gorm.DB, error) {
dsn := getMysqlDsn(ctx, mysqlConfig)
db, err := gorm.Open(mysql.Open(dsn), gormConfig)
if err != nil {
if isMysqlErrorWithCode(err, dbNotExists) {
logger.Infof(ctx, "Creating database %v", mysqlConfig.DbName)
withDefaultDB := MysqlConfig{
Host: mysqlConfig.Host,
Port: mysqlConfig.Port,
DbName: "mysql", // should always exist
User: mysqlConfig.User,
Password: mysqlConfig.Password,
}
dsn := getMysqlDsn(ctx, withDefaultDB)

defaultDB, err := gorm.Open(mysql.New(mysql.Config{
DSN: dsn,
DefaultStringSize: 100,
// TODO: do we need to set DefaultDatetimePrecision? What about other fields?
}), gormConfig)
if err != nil {
return nil, fmt.Errorf("error creating connection to default database for %s:%d",
withDefaultDB.Host, withDefaultDB.Port)
}
createDBStatement := fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s;", mysqlConfig.DbName)
result := defaultDB.Exec(createDBStatement)
if result.Error != nil {
return nil, result.Error
}
// Now that the database should be there, try again.
return CreateMysqlDbIfNotExists(ctx, gormConfig, mysqlConfig)
}
logger.Debugf(ctx, "Error opening MySQL connection %s", err)
return nil, err
}
return db, nil
}

func isMysqlErrorWithCode(err error, code uint16) bool {
myErr := &mysql_driver.MySQLError{}
if !errors.As(err, &myErr) {
// err chain does not contain a MySQLError
return false
}

// MySQLError found in chain and set to code specified
return myErr.Number == code
}
22 changes: 22 additions & 0 deletions database/mysql_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package database

import (
"fmt"
"github.com/stretchr/testify/assert"
)
import "context"
import "testing"

func TestGettingMysqlDsn(t *testing.T) {
mysql := MysqlConfig{
Host: "some.ho.st",
Port: 3306,
DbName: "flyteadmin",
User: "user",
Password: "pass",
}
ctx := context.Background()
xx := getMysqlDsn(ctx, mysql)
fmt.Printf("connection string!!! %s", xx)
assert.Equal(t, "user:pass@tcp(some.ho.st:3306)/flyteadmin?allowCleartextPasswords=true&allowNativePasswords=false&checkConnLiveness=false&multiStatements=true&maxAllowedPacket=0", xx)
}
113 changes: 113 additions & 0 deletions database/postgres.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package database

import (
"context"
"errors"
"fmt"
"github.com/flyteorg/flytestdlib/logger"
"github.com/jackc/pgconn"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"os"
"strings"
)

const pqInvalidDBCode = "3D000"
const pqDbAlreadyExistsCode = "42P04"
const defaultDB = "postgres"

// Resolves a password value from either a user-provided inline value or a filepath whose contents contain a password.
// Possibly postgres specific, leaving in this file for now
func resolvePassword(ctx context.Context, passwordVal, passwordPath string) string {
password := passwordVal
if len(passwordPath) > 0 {
if _, err := os.Stat(passwordPath); os.IsNotExist(err) {
logger.Fatalf(ctx,
"missing database password at specified path [%s]", passwordPath)
}
passwordVal, err := os.ReadFile(passwordPath)
if err != nil {
logger.Fatalf(ctx, "failed to read database password from path [%s] with err: %v",
passwordPath, err)
}
// Passwords can contain special characters as long as they are percent encoded
// https://www.postgresql.org/docs/current/libpq-connect.html
password = strings.TrimSpace(string(passwordVal))
}
return password
}

// Produces the DSN (data source name) for opening a postgres db connection.
func getPostgresDsn(ctx context.Context, pgConfig PostgresConfig) string {
password := resolvePassword(ctx, pgConfig.Password, pgConfig.PasswordPath)
if len(password) == 0 {
// The password-less case is included for development environments.
return fmt.Sprintf("host=%s port=%d dbname=%s user=%s sslmode=disable",
pgConfig.Host, pgConfig.Port, pgConfig.DbName, pgConfig.User)
}
return fmt.Sprintf("host=%s port=%d dbname=%s user=%s password=%s %s",
pgConfig.Host, pgConfig.Port, pgConfig.DbName, pgConfig.User, password, pgConfig.ExtraOptions)
}

func PostgresDsn(ctx context.Context, pgConfig PostgresConfig) string {
password := resolvePassword(ctx, pgConfig.Password, pgConfig.PasswordPath)
if len(password) == 0 {
// The password-less case is included for development environments.
return fmt.Sprintf("host=%s port=%d dbname=%s user=%s sslmode=disable",
pgConfig.Host, pgConfig.Port, pgConfig.DbName, pgConfig.User)
}
return fmt.Sprintf("host=%s port=%d dbname=%s user=%s password=%s %s",
pgConfig.Host, pgConfig.Port, pgConfig.DbName, pgConfig.User, password, pgConfig.ExtraOptions)
}

// CreatePostgresDbIfNotExists creates DB if it doesn't exist for the passed in config
func CreatePostgresDbIfNotExists(ctx context.Context, gormConfig *gorm.Config, pgConfig PostgresConfig) (*gorm.DB, error) {
dialector := postgres.Open(getPostgresDsn(ctx, pgConfig))
gormDb, err := gorm.Open(dialector, gormConfig)
if err == nil {
return gormDb, nil
}

if !isPgErrorWithCode(err, pqInvalidDBCode) {
return nil, err
}

logger.Warningf(ctx, "Database [%v] does not exist", pgConfig.DbName)

// Every postgres installation includes a 'postgres' database by default. We connect to that now in order to
// initialize the user-specified database.
defaultDbPgConfig := pgConfig
defaultDbPgConfig.DbName = defaultDB
defaultDBDialector := postgres.Open(getPostgresDsn(ctx, defaultDbPgConfig))
gormDb, err = gorm.Open(defaultDBDialector, gormConfig)
if err != nil {
return nil, err
}

// Because we asserted earlier that the db does not exist, we create it now.
logger.Infof(ctx, "Creating database %v", pgConfig.DbName)

// NOTE: golang sql drivers do not support parameter injection for CREATE calls
createDBStatement := fmt.Sprintf("CREATE DATABASE %s", pgConfig.DbName)
result := gormDb.Exec(createDBStatement)

if result.Error != nil {
if !isPgErrorWithCode(result.Error, pqDbAlreadyExistsCode) {
return nil, result.Error
}
logger.Warningf(ctx, "Got DB already exists error for [%s], skipping...", pgConfig.DbName)
}
// Now try connecting to the db again
return gorm.Open(dialector, gormConfig)
}

func isPgErrorWithCode(err error, code string) bool {
pgErr := &pgconn.PgError{}
if !errors.As(err, &pgErr) {
// err chain does not contain a pgconn.PgError
return false
}

// pgconn.PgError found in chain and set to code specified
return pgErr.Code == code
}
Loading