From d4b0fca45134bec52f40a7e98acd8aab97e43a7e Mon Sep 17 00:00:00 2001 From: Jean Rouge Date: Sun, 18 Oct 2020 12:11:15 -0700 Subject: [PATCH] Making lib/backend/client.getFactory public Mainly to allow writing "wrapper" custom backends, see eg https://github.com/uber/kraken/issues/278#issuecomment-705121495 Signed-off-by: Jean Rouge --- lib/backend/client.go | 5 +++-- lib/backend/client_test.go | 32 ++++++++++++++++++++++++++++++++ lib/backend/manager.go | 2 +- 3 files changed, 36 insertions(+), 3 deletions(-) create mode 100644 lib/backend/client_test.go diff --git a/lib/backend/client.go b/lib/backend/client.go index 317843ddc..c54fc3f5e 100644 --- a/lib/backend/client.go +++ b/lib/backend/client.go @@ -32,8 +32,9 @@ func Register(name string, factory ClientFactory) { _factories[name] = factory } -// getFactory returns backend client factory given client name. -func getFactory(name string) (ClientFactory, error) { +// GetFactory returns backend client factory given client name. +// This function should stay public to allow for wrapper custom backends. +func GetFactory(name string) (ClientFactory, error) { factory, ok := _factories[name] if !ok { return nil, fmt.Errorf("no backend client defined with name %s", name) diff --git a/lib/backend/client_test.go b/lib/backend/client_test.go new file mode 100644 index 000000000..0a29cff38 --- /dev/null +++ b/lib/backend/client_test.go @@ -0,0 +1,32 @@ +package backend + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRegisterAndGetFactory(t *testing.T) { + t.Run("round-trip", func(t *testing.T) { + name := "dummy" + factory := &dummyFactory{} + + Register(name, factory) + roundTrip, err := GetFactory(name) + + assert.NoError(t, err) + assert.Equal(t, factory, roundTrip) + }) + + t.Run("GetFactory errors out on missing factory", func(t *testing.T) { + _, err := GetFactory("i_dont_exist") + + assert.Error(t, err) + }) +} + +type dummyFactory struct{} + +func (f *dummyFactory) Create(config interface{}, authConfig interface{}) (Client, error) { + return nil, nil +} diff --git a/lib/backend/manager.go b/lib/backend/manager.go index 74f87bdd7..b98043bc0 100644 --- a/lib/backend/manager.go +++ b/lib/backend/manager.go @@ -62,7 +62,7 @@ func NewManager(configs []Config, auth AuthConfig) (*Manager, error) { var backendConfig interface{} for name, backendConfig = range config.Backend { // Pull the only key/value out of map } - factory, err := getFactory(name) + factory, err := GetFactory(name) if err != nil { return nil, fmt.Errorf("get backend client factory: %s", err) }