Skip to content

Commit b04f811

Browse files
committed
Add environment variable option to set postgres ssl mode
Signed-off-by: Kun Chang <[email protected]>
1 parent fc858d1 commit b04f811

File tree

3 files changed

+61
-7
lines changed

3 files changed

+61
-7
lines changed

pkg/db/v1beta1/common/const.go

+2
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,13 @@ const (
4141
PostgreSQLDBHostEnvName = "KATIB_POSTGRESQL_DB_HOST"
4242
PostgreSQLDBPortEnvName = "KATIB_POSTGRESQL_DB_PORT"
4343
PostgreSQLDatabase = "KATIB_POSTGRESQL_DB_DATABASE"
44+
PostgreSSLMode = "KATIB_POSTGRESQL_SSL_MODE"
4445

4546
DefaultPostgreSQLUser = "katib"
4647
DefaultPostgreSQLDatabase = "katib"
4748
DefaultPostgreSQLHost = "katib-postgres"
4849
DefaultPostgreSQLPort = "5432"
50+
DefaultPostgreSSLMode = "disable"
4951

5052
SkipDbInitializationEnvName = "SKIP_DB_INITIALIZATION"
5153
)

pkg/db/v1beta1/postgres/postgres.go

+4-2
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,12 @@ func getDbName() string {
4848
common.PostgreSQLDBPortEnvName, common.DefaultPostgreSQLPort)
4949
dbName := env.GetEnvOrDefault(common.PostgreSQLDatabase,
5050
common.DefaultPostgreSQLDatabase)
51+
sslMode := env.GetEnvOrDefault(common.PostgreSSLMode,
52+
common.DefaultPostgreSSLMode)
5153

5254
psqlInfo := fmt.Sprintf("host=%s port=%s user=%s "+
53-
"password=%s dbname=%s sslmode=disable",
54-
dbHost, dbPort, dbUser, dbPass, dbName)
55+
"password=%s dbname=%s sslmode=%s",
56+
dbHost, dbPort, dbUser, dbPass, dbName, sslMode)
5557

5658
return psqlInfo
5759
}

pkg/db/v1beta1/postgres/postgres_test.go

+55-5
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"testing"
2323

2424
sqlmock "github.com/DATA-DOG/go-sqlmock"
25+
"github.com/google/go-cmp/cmp"
2526
_ "github.com/lib/pq"
2627

2728
api_pb "github.com/kubeflow/katib/pkg/apis/manager/v1beta1"
@@ -129,11 +130,60 @@ func TestDeleteObservationLog(t *testing.T) {
129130
}
130131

131132
func TestGetDbName(t *testing.T) {
132-
// dbName := "root:@tcp(katib-mysql:3306)/katib?timeout=5s"
133-
dbName := "host=katib-postgres port=5432 user=katib password= dbname=katib sslmode=disable"
134-
135-
if getDbName() != dbName {
136-
t.Errorf("getDbName returns wrong value %v", getDbName())
133+
cases := map[string]struct {
134+
updateEnvs map[string]string
135+
wantName string
136+
}{
137+
"All parameters are default": {
138+
wantName: "host=katib-postgres port=5432 user=katib password= dbname=katib sslmode=disable",
139+
},
140+
"Set DB_USER": {
141+
updateEnvs: map[string]string{
142+
common.DBUserEnvName: "testUser",
143+
},
144+
wantName: "host=katib-postgres port=5432 user=testUser password= dbname=katib sslmode=disable",
145+
},
146+
"Set KATIB_POSTGRESQL_DB_HOST": {
147+
updateEnvs: map[string]string{
148+
common.PostgreSQLDBHostEnvName: "testHost",
149+
},
150+
wantName: "host=testHost port=5432 user=katib password= dbname=katib sslmode=disable",
151+
},
152+
"Set KATIB_POSTGRESQL_DB_PORT": {
153+
updateEnvs: map[string]string{
154+
common.PostgreSQLDBPortEnvName: "1234",
155+
},
156+
wantName: "host=katib-postgres port=1234 user=katib password= dbname=katib sslmode=disable",
157+
},
158+
"Set KATIB_POSTGRESQL_DB_DATABASE": {
159+
updateEnvs: map[string]string{
160+
common.PostgreSQLDatabase: "testDB",
161+
},
162+
wantName: "host=katib-postgres port=5432 user=katib password= dbname=testDB sslmode=disable",
163+
},
164+
"Set DB_PASSWORD": {
165+
updateEnvs: map[string]string{
166+
common.DBPasswordEnvName: "testPassword",
167+
},
168+
wantName: "host=katib-postgres port=5432 user=katib password=testPassword dbname=katib sslmode=disable",
169+
},
170+
"Set KATIB_POSTGRESQL_SSL_MODE": {
171+
updateEnvs: map[string]string{
172+
common.PostgreSSLMode: "require",
173+
},
174+
wantName: "host=katib-postgres port=5432 user=katib password= dbname=katib sslmode=require",
175+
},
137176
}
138177

178+
for name, tc := range cases {
179+
t.Run(name, func(t *testing.T) {
180+
for k, v := range tc.updateEnvs {
181+
t.Setenv(k, v)
182+
}
183+
gotName := getDbName()
184+
if diff := cmp.Diff(tc.wantName, gotName); len(diff) != 0 {
185+
t.Errorf("Unexpected DBName (-want,+got):\n%s", diff)
186+
}
187+
})
188+
}
139189
}

0 commit comments

Comments
 (0)