diff --git a/README.md b/README.md index 6a65c8b..4c225b6 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ If anyone decides to use this and wants to request a specific feature or even fi ## Short term TODO list: 1. [x] Change the logger (current one is a mess) -2. [ ] Finish adding tests (wip in another branch) +2. [x] Finish adding tests (initial batch of tests) 3. [ ] Finish implementing Update to existing secrets - [ ] Bonus: Create net new ones. 4. [ ] Support for namespace changes. \ No newline at end of file diff --git a/component/mounts_table_test.go b/component/mounts_table_test.go index 0818382..d2b3028 100644 --- a/component/mounts_table_test.go +++ b/component/mounts_table_test.go @@ -14,7 +14,7 @@ import ( func TestMountsTable_Pass(t *testing.T) { r := require.New(t) - t.Run("When the component is bound", func(t *testing.T) { + t.Run("When there is data to render", func(t *testing.T) { fakeTable := &componentfakes.FakeTable{} mTable := component.NewMountsTable() diff --git a/component/search_test.go b/component/search_test.go new file mode 100644 index 0000000..3437406 --- /dev/null +++ b/component/search_test.go @@ -0,0 +1,101 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package component_test + +import ( + "errors" + "testing" + + "github.com/dkyanakiev/vaulty/component" + "github.com/dkyanakiev/vaulty/component/componentfakes" + "github.com/gdamore/tcell/v2" + "github.com/rivo/tview" + "github.com/stretchr/testify/require" +) + +func TestSearch_Pass(t *testing.T) { + r := require.New(t) + + input := &componentfakes.FakeInputField{} + search := component.NewSearchField("test") + search.InputField = input + + var changedCalled bool + search.Props.ChangedFunc = func(text string) { + changedCalled = true + } + + var doneCalled bool + search.Props.DoneFunc = func(key tcell.Key) { + doneCalled = true + } + search.Bind(tview.NewFlex()) + + err := search.Render() + r.NoError(err) + + actualDoneFunc := input.SetDoneFuncArgsForCall(0) + actualChangedFunc := input.SetChangedFuncArgsForCall(0) + + actualChangedFunc("") + actualDoneFunc(tcell.KeyACK) + + r.True(changedCalled) + r.True(doneCalled) +} + +func TestSearch_Fail(t *testing.T) { + r := require.New(t) + + t.Run("When the component isn't bound", func(t *testing.T) { + input := &componentfakes.FakeInputField{} + search := component.NewSearchField("test") + search.InputField = input + search.Props.ChangedFunc = func(text string) {} + search.Props.DoneFunc = func(key tcell.Key) {} + + err := search.Render() + r.Error(err) + + // It provides the correct error message + r.EqualError(err, "component not bound") + + // It is the correct error + r.True(errors.Is(err, component.ErrComponentNotBound)) + }) + + t.Run("When DoneFunc is not set", func(t *testing.T) { + input := &componentfakes.FakeInputField{} + search := component.NewSearchField("test") + search.InputField = input + search.Props.ChangedFunc = func(text string) {} + search.Bind(tview.NewFlex()) + + err := search.Render() + r.Error(err) + + // It provides the correct error message + r.EqualError(err, "component properties not set") + + // It is the correct error + r.True(errors.Is(err, component.ErrComponentPropsNotSet)) + }) + + t.Run("When ChangedFunc is not set", func(t *testing.T) { + input := &componentfakes.FakeInputField{} + search := component.NewSearchField("test") + search.InputField = input + search.Props.DoneFunc = func(key tcell.Key) {} + search.Bind(tview.NewFlex()) + + err := search.Render() + r.Error(err) + + // It provides the correct error message + r.EqualError(err, "component properties not set") + + // It is the correct error + r.True(errors.Is(err, component.ErrComponentPropsNotSet)) + }) +} diff --git a/component/secret_obj_table_test.go b/component/secret_obj_table_test.go new file mode 100644 index 0000000..af96bc0 --- /dev/null +++ b/component/secret_obj_table_test.go @@ -0,0 +1,143 @@ +package component_test + +import ( + "encoding/json" + "testing" + + "github.com/dkyanakiev/vaulty/component" + "github.com/dkyanakiev/vaulty/component/componentfakes" + "github.com/gdamore/tcell/v2" + "github.com/hashicorp/vault/api" + "github.com/rivo/tview" + "github.com/stretchr/testify/require" +) + +func TestSecretObjTable_Pass(t *testing.T) { + r := require.New(t) + + t.Run("Render data as table", func(t *testing.T) { + fakeTable := &componentfakes.FakeTable{} + fakeTextView := &componentfakes.FakeTextView{} + st := component.NewSecretObjTable() + + st.Table = fakeTable + st.TextView = fakeTextView + st.ShowJson = false + st.Editable = false + + mockSecret := &api.Secret{ + RequestID: "mockRequestID", + LeaseID: "mockLeaseID", + LeaseDuration: 3600, + Renewable: true, + Data: map[string]interface{}{ + "data": map[string]interface{}{ + "key1": "dZpT6XnlnktMXaYF", + "key2": "10mNsYOLfd1OfohW", + }, + }, + } + + st.Props.Data = mockSecret + + st.Props.SelectPath = func(id string) {} + st.Props.HandleNoResources = func(format string, args ...interface{}) {} + slot := tview.NewFlex() + st.Bind(slot) + // It doesn't error + err := st.Render() + r.NoError(err) + + // Render header rows + renderHeaderCount := fakeTable.RenderHeaderCallCount() + r.Equal(renderHeaderCount, 1) + headers := fakeTable.RenderHeaderArgsForCall(0) + r.Equal(headers, component.SecretObjTableHeaderJobs) + + // Render rows + renderRowCallCount := fakeTable.RenderRowCallCount() + r.Equal(renderRowCallCount, 2) + + row1, index1, c1 := fakeTable.RenderRowArgsForCall(0) + row2, index2, c2 := fakeTable.RenderRowArgsForCall(1) + + expectedRow1 := []string{"key1", "dZpT6XnlnktMXaYF"} + expectedRow2 := []string{"key2", "10mNsYOLfd1OfohW"} + + r.Equal(expectedRow1, row1) + r.Equal(expectedRow2, row2) + r.Equal(index1, 1) + r.Equal(index2, 2) + r.Equal(c1, tcell.ColorYellow) + r.Equal(c2, tcell.ColorYellow) + + }) + + t.Run("Render data as json", func(t *testing.T) { + fakeTable := &componentfakes.FakeTable{} + fakeTextView := &componentfakes.FakeTextView{} + st := component.NewSecretObjTable() + + st.Table = fakeTable + st.TextView = fakeTextView + st.ShowJson = true + st.Editable = false + + mockSecret := &api.Secret{ + RequestID: "mockRequestID", + LeaseID: "mockLeaseID", + LeaseDuration: 3600, + Renewable: true, + Data: map[string]interface{}{ + "data": map[string]interface{}{ + "key1": "dZpT6XnlnktMXaYF", + "key2": "10mNsYOLfd1OfohW", + }, + }, + } + + st.Props.Data = mockSecret + correctText, _ := json.Marshal(mockSecret.Data["data"]) + + st.Props.SelectPath = func(id string) {} + st.Props.HandleNoResources = func(format string, args ...interface{}) {} + slot := tview.NewFlex() + st.Bind(slot) + // It doesn't error + err := st.Render() + r.NoError(err) + + // Renders correct text + + fakeTextView.GetTextReturns(string(correctText)) + renderedText := fakeTextView.GetText(true) + + r.Equal(string(correctText), renderedText) + }) + + t.Run("No data to render", func(t *testing.T) { + fakeTable := &componentfakes.FakeTable{} + fakeTextView := &componentfakes.FakeTextView{} + st := component.NewSecretObjTable() + + st.Table = fakeTable + st.TextView = fakeTextView + st.ShowJson = false + st.Editable = false + + st.Props.Data = nil + + var NoResourcesCalled bool + st.Props.HandleNoResources = func(format string, args ...interface{}) { + NoResourcesCalled = true + } + + slot := tview.NewFlex() + st.Bind(slot) + // It doesn't error + err := st.Render() + r.NoError(err) + + r.True(NoResourcesCalled) + }) +} diff --git a/component/secrets_table_test.go b/component/secrets_table_test.go new file mode 100644 index 0000000..cc3d7bd --- /dev/null +++ b/component/secrets_table_test.go @@ -0,0 +1,101 @@ +package component_test + +import ( + "testing" + + "github.com/dkyanakiev/vaulty/component" + "github.com/dkyanakiev/vaulty/component/componentfakes" + "github.com/dkyanakiev/vaulty/models" + "github.com/dkyanakiev/vaulty/styles" + "github.com/gdamore/tcell/v2" + "github.com/rivo/tview" + "github.com/stretchr/testify/require" +) + +func TestSecretsTable_Pass(t *testing.T) { + r := require.New(t) + t.Run("When there is data to render", func(t *testing.T) { + + fakeTable := &componentfakes.FakeTable{} + st := component.NewSecretsTable() + + st.Table = fakeTable + st.Props.Namespace = "default" + mockData := []models.SecretPath{ + { + PathName: "mockPathName1", + IsSecret: true, + }, + { + PathName: "mockPathName2", + IsSecret: false, + }, + } + + st.Props.Data = mockData + st.Props.SelectPath = func(id string) {} + st.Props.HandleNoResources = func(format string, args ...interface{}) {} + + slot := tview.NewFlex() + st.Bind(slot) + // It doesn't error + err := st.Render() + r.NoError(err) + + // Render header rows + renderHeaderCount := fakeTable.RenderHeaderCallCount() + r.Equal(renderHeaderCount, 1) + + // Correct headers + header := fakeTable.RenderHeaderArgsForCall(0) + r.Equal(component.SecretsTableHeaderJobs, header) + + // It renders the correct number of rows + renderRowCallCount := fakeTable.RenderRowCallCount() + r.Equal(renderRowCallCount, 2) + + row1, index1, c1 := fakeTable.RenderRowArgsForCall(0) + row2, index2, c2 := fakeTable.RenderRowArgsForCall(1) + expectedRow1 := []string{"mockPathName1", "true"} + expectedRow2 := []string{"mockPathName2", "false"} + + r.Equal(expectedRow1, row1) + r.Equal(expectedRow2, row2) + + r.Equal(index1, 1) + r.Equal(index2, 2) + r.Equal(c1, tcell.ColorYellow) + r.Equal(c2, tcell.ColorYellow) + + }) + + t.Run("No data to render", func(t *testing.T) { + fakeTable := &componentfakes.FakeTable{} + st := component.NewSecretsTable() + + st.Table = fakeTable + st.Props.Namespace = "default" + + st.Props.Data = nil + st.Props.SelectPath = func(id string) {} + st.Props.HandleNoResources = func(format string, args ...interface{}) {} + + var NoResourcesCalled bool + st.Props.HandleNoResources = func(format string, args ...interface{}) { + NoResourcesCalled = true + + r.Equal("%sno secrets available\n¯%s\\_( ͡• ͜ʖ ͡•)_/¯", format) + r.Len(args, 2) + r.Equal(args[0], styles.HighlightPrimaryTag) + r.Equal(args[1], styles.HighlightSecondaryTag) + } + slot := tview.NewFlex() + st.Bind(slot) + // It doesn't error + err := st.Render() + r.NoError(err) + r.True(NoResourcesCalled) + + }) + +} diff --git a/component/vaultinfo_test.go b/component/vaultinfo_test.go new file mode 100644 index 0000000..961e595 --- /dev/null +++ b/component/vaultinfo_test.go @@ -0,0 +1,44 @@ +package component_test + +import ( + "errors" + "testing" + + "github.com/dkyanakiev/vaulty/component" + "github.com/dkyanakiev/vaulty/component/componentfakes" + "github.com/rivo/tview" + "github.com/stretchr/testify/require" +) + +func TestVaultInfo_Pass(t *testing.T) { + r := require.New(t) + + textView := &componentfakes.FakeTextView{} + vaultInfo := component.NewVaultInfo() + vaultInfo.TextView = textView + + vaultInfo.Props.Info = "info" + + vaultInfo.Bind(tview.NewFlex()) + + err := vaultInfo.Render() + r.NoError(err) + + text := textView.SetTextArgsForCall(0) + r.Equal(text, "info") +} + +func TestVaultInfo_Failt(t *testing.T) { + r := require.New(t) + + textView := &componentfakes.FakeTextView{} + vaultInfo := component.NewVaultInfo() + vaultInfo.TextView = textView + vaultInfo.Props.Info = "info" + + err := vaultInfo.Render() + r.Error(err) + + r.True(errors.Is(err, component.ErrComponentNotBound)) + r.EqualError(err, "component not bound") +} diff --git a/vault/client.go b/vault/client.go index b4bb2b5..3af775a 100644 --- a/vault/client.go +++ b/vault/client.go @@ -12,7 +12,6 @@ type Client interface { Address() string } -//go:generate counterfeiter . Vault type Vault struct { vault *api.Client Client Client diff --git a/vault/kv_test.go b/vault/kv_test.go new file mode 100644 index 0000000..64e6efb --- /dev/null +++ b/vault/kv_test.go @@ -0,0 +1,50 @@ +package vault_test + +import ( + "context" + "testing" + + "github.com/dkyanakiev/vaulty/vault" + "github.com/dkyanakiev/vaulty/vault/vaultfakes" + "github.com/hashicorp/vault/api" + "github.com/stretchr/testify/assert" +) + +func TestGet(t *testing.T) { + ctx := context.Background() + path := "testpath" + + fakeKV2 := &vaultfakes.FakeKV2{} + + fakeKV2.GetReturns(&api.KVSecret{}, nil) + + v := &vault.Vault{ + KV2: fakeKV2, + } + + secret, err := v.Get(ctx, path) + + assert.NoError(t, err) + assert.NotNil(t, secret) + fakeKV2.Get(ctx, path) + +} + +func TestGetMetadata(t *testing.T) { + ctx := context.Background() + path := "testpath" + + fakeKV2 := &vaultfakes.FakeKV2{} + + fakeKV2.GetMetadataReturns(&api.KVMetadata{}, nil) + + v := &vault.Vault{ + KV2: fakeKV2, + } + + secret, err := v.GetMetadata(ctx, path) + + assert.NoError(t, err) + assert.NotNil(t, secret) + fakeKV2.GetMetadata(ctx, path) +} diff --git a/vault/logical.go b/vault/logical.go index 17ad52f..ca2461c 100644 --- a/vault/logical.go +++ b/vault/logical.go @@ -23,7 +23,8 @@ func (v *Vault) ListWithContext(ctx context.Context, path string) (*api.Secret, r.Method = http.MethodGet r.Params.Set("list", "true") - resp, err := v.vault.RawRequestWithContext(ctx, r) + // resp, err := v.vault.RawRequestWithContext(ctx, r) + resp, err := v.vault.Logical().ReadRawWithContext(ctx, path) if resp != nil { defer resp.Body.Close() } diff --git a/vault/logical_test.go b/vault/logical_test.go new file mode 100644 index 0000000..13d0615 --- /dev/null +++ b/vault/logical_test.go @@ -0,0 +1,22 @@ +package vault_test + +import ( + "testing" + + "github.com/dkyanakiev/vaulty/vault" + "github.com/dkyanakiev/vaulty/vault/vaultfakes" + "github.com/stretchr/testify/assert" +) + +func TestList(t *testing.T) { + path := "testpath" + + fakeLogical := &vaultfakes.FakeLogical{} + + v := &vault.Vault{ + Logical: fakeLogical, + } + _, err := v.Logical.List(path) + assert.NoError(t, err) + +} diff --git a/vault/mounts_test.go b/vault/mounts_test.go new file mode 100644 index 0000000..0859f3a --- /dev/null +++ b/vault/mounts_test.go @@ -0,0 +1,24 @@ +package vault_test + +import ( + "testing" + + "github.com/dkyanakiev/vaulty/vault" + "github.com/dkyanakiev/vaulty/vault/vaultfakes" + "github.com/hashicorp/vault/api" + "github.com/stretchr/testify/assert" +) + +func TestListMounts(t *testing.T) { + + fakeSys := &vaultfakes.FakeSys{} + fakeSys.ListMountsReturns(map[string]*api.MountOutput{}, nil) + + v := &vault.Vault{ + Sys: fakeSys, + } + + _, err := v.Sys.ListMounts() + assert.NoError(t, err) + +} diff --git a/vault/policy_test.go b/vault/policy_test.go new file mode 100644 index 0000000..4c08e56 --- /dev/null +++ b/vault/policy_test.go @@ -0,0 +1,36 @@ +package vault_test + +import ( + "testing" + + "github.com/dkyanakiev/vaulty/vault" + "github.com/dkyanakiev/vaulty/vault/vaultfakes" + "github.com/stretchr/testify/assert" +) + +func TestAllPolicies(t *testing.T) { + + fakeSys := &vaultfakes.FakeSys{} + + v := &vault.Vault{ + Sys: fakeSys, + } + + _, err := v.AllPolicies() + + assert.NoError(t, err) + +} + +func TestGetPolicyInfo(t *testing.T) { + + fakeSys := &vaultfakes.FakeSys{} + + v := &vault.Vault{ + Sys: fakeSys, + } + + _, err := v.GetPolicyInfo("test") + + assert.NoError(t, err) +} diff --git a/vault/secret.go b/vault/secret.go index 8ccdfca..a3981d1 100644 --- a/vault/secret.go +++ b/vault/secret.go @@ -10,40 +10,6 @@ import ( "github.com/hashicorp/vault/api" ) -// func (v *Vault) ListSecrets(path string) (*api.Secret, error) { -// // Get mount information -// mounts, err := v.vault.Sys().ListMounts() -// if err != nil { -// return nil, fmt.Errorf("unable to list mounts: %w", err) -// } - -// // Check if the mount is KV1 or KV2 -// version := mounts[path+"/"].Options["version"] -// if version == "" { -// version = "1" -// } - -// // List secrets -// var secret *api.Secret -// if version == "1" { -// secret, err = v.vault.Logical().List(path) -// } else { -// secret, err = v.vault.Logical().List(fmt.Sprintf("%s/metadata", path)) -// } -// if err != nil { -// return nil, fmt.Errorf("unable to list secrets for path %s: %w", path, err) -// } - -// // If the secret is wrapped, return the wrapped response -// if secret != nil && secret.WrapInfo != nil && secret.WrapInfo.TTL != 0 { -// // TODO: Handle this use case -// fmt.Println("Wrapped") -// // return OutputSecret(c.UI, secret) -// } - -// return secret, nil -// } - func (v *Vault) ListSecrets(path string) (*api.Secret, error) { secret, err := v.vault.Logical().List(fmt.Sprintf("%s/metadata", path)) diff --git a/vault/vaultfakes/fake_kv2.go b/vault/vaultfakes/fake_kv2.go new file mode 100644 index 0000000..38e2896 --- /dev/null +++ b/vault/vaultfakes/fake_kv2.go @@ -0,0 +1,201 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package vaultfakes + +import ( + "context" + "sync" + + "github.com/dkyanakiev/vaulty/vault" + "github.com/hashicorp/vault/api" +) + +type FakeKV2 struct { + GetStub func(context.Context, string) (*api.KVSecret, error) + getMutex sync.RWMutex + getArgsForCall []struct { + arg1 context.Context + arg2 string + } + getReturns struct { + result1 *api.KVSecret + result2 error + } + getReturnsOnCall map[int]struct { + result1 *api.KVSecret + result2 error + } + GetMetadataStub func(context.Context, string) (*api.KVMetadata, error) + getMetadataMutex sync.RWMutex + getMetadataArgsForCall []struct { + arg1 context.Context + arg2 string + } + getMetadataReturns struct { + result1 *api.KVMetadata + result2 error + } + getMetadataReturnsOnCall map[int]struct { + result1 *api.KVMetadata + result2 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeKV2) Get(arg1 context.Context, arg2 string) (*api.KVSecret, error) { + fake.getMutex.Lock() + ret, specificReturn := fake.getReturnsOnCall[len(fake.getArgsForCall)] + fake.getArgsForCall = append(fake.getArgsForCall, struct { + arg1 context.Context + arg2 string + }{arg1, arg2}) + stub := fake.GetStub + fakeReturns := fake.getReturns + fake.recordInvocation("Get", []interface{}{arg1, arg2}) + fake.getMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeKV2) GetCallCount() int { + fake.getMutex.RLock() + defer fake.getMutex.RUnlock() + return len(fake.getArgsForCall) +} + +func (fake *FakeKV2) GetCalls(stub func(context.Context, string) (*api.KVSecret, error)) { + fake.getMutex.Lock() + defer fake.getMutex.Unlock() + fake.GetStub = stub +} + +func (fake *FakeKV2) GetArgsForCall(i int) (context.Context, string) { + fake.getMutex.RLock() + defer fake.getMutex.RUnlock() + argsForCall := fake.getArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeKV2) GetReturns(result1 *api.KVSecret, result2 error) { + fake.getMutex.Lock() + defer fake.getMutex.Unlock() + fake.GetStub = nil + fake.getReturns = struct { + result1 *api.KVSecret + result2 error + }{result1, result2} +} + +func (fake *FakeKV2) GetReturnsOnCall(i int, result1 *api.KVSecret, result2 error) { + fake.getMutex.Lock() + defer fake.getMutex.Unlock() + fake.GetStub = nil + if fake.getReturnsOnCall == nil { + fake.getReturnsOnCall = make(map[int]struct { + result1 *api.KVSecret + result2 error + }) + } + fake.getReturnsOnCall[i] = struct { + result1 *api.KVSecret + result2 error + }{result1, result2} +} + +func (fake *FakeKV2) GetMetadata(arg1 context.Context, arg2 string) (*api.KVMetadata, error) { + fake.getMetadataMutex.Lock() + ret, specificReturn := fake.getMetadataReturnsOnCall[len(fake.getMetadataArgsForCall)] + fake.getMetadataArgsForCall = append(fake.getMetadataArgsForCall, struct { + arg1 context.Context + arg2 string + }{arg1, arg2}) + stub := fake.GetMetadataStub + fakeReturns := fake.getMetadataReturns + fake.recordInvocation("GetMetadata", []interface{}{arg1, arg2}) + fake.getMetadataMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeKV2) GetMetadataCallCount() int { + fake.getMetadataMutex.RLock() + defer fake.getMetadataMutex.RUnlock() + return len(fake.getMetadataArgsForCall) +} + +func (fake *FakeKV2) GetMetadataCalls(stub func(context.Context, string) (*api.KVMetadata, error)) { + fake.getMetadataMutex.Lock() + defer fake.getMetadataMutex.Unlock() + fake.GetMetadataStub = stub +} + +func (fake *FakeKV2) GetMetadataArgsForCall(i int) (context.Context, string) { + fake.getMetadataMutex.RLock() + defer fake.getMetadataMutex.RUnlock() + argsForCall := fake.getMetadataArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeKV2) GetMetadataReturns(result1 *api.KVMetadata, result2 error) { + fake.getMetadataMutex.Lock() + defer fake.getMetadataMutex.Unlock() + fake.GetMetadataStub = nil + fake.getMetadataReturns = struct { + result1 *api.KVMetadata + result2 error + }{result1, result2} +} + +func (fake *FakeKV2) GetMetadataReturnsOnCall(i int, result1 *api.KVMetadata, result2 error) { + fake.getMetadataMutex.Lock() + defer fake.getMetadataMutex.Unlock() + fake.GetMetadataStub = nil + if fake.getMetadataReturnsOnCall == nil { + fake.getMetadataReturnsOnCall = make(map[int]struct { + result1 *api.KVMetadata + result2 error + }) + } + fake.getMetadataReturnsOnCall[i] = struct { + result1 *api.KVMetadata + result2 error + }{result1, result2} +} + +func (fake *FakeKV2) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.getMutex.RLock() + defer fake.getMutex.RUnlock() + fake.getMetadataMutex.RLock() + defer fake.getMetadataMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeKV2) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ vault.KV2 = new(FakeKV2) diff --git a/vault/vaultfakes/fake_logical.go b/vault/vaultfakes/fake_logical.go new file mode 100644 index 0000000..fde0167 --- /dev/null +++ b/vault/vaultfakes/fake_logical.go @@ -0,0 +1,117 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package vaultfakes + +import ( + "sync" + + "github.com/dkyanakiev/vaulty/vault" + "github.com/hashicorp/vault/api" +) + +type FakeLogical struct { + ListStub func(string) (*api.Secret, error) + listMutex sync.RWMutex + listArgsForCall []struct { + arg1 string + } + listReturns struct { + result1 *api.Secret + result2 error + } + listReturnsOnCall map[int]struct { + result1 *api.Secret + result2 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeLogical) List(arg1 string) (*api.Secret, error) { + fake.listMutex.Lock() + ret, specificReturn := fake.listReturnsOnCall[len(fake.listArgsForCall)] + fake.listArgsForCall = append(fake.listArgsForCall, struct { + arg1 string + }{arg1}) + stub := fake.ListStub + fakeReturns := fake.listReturns + fake.recordInvocation("List", []interface{}{arg1}) + fake.listMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeLogical) ListCallCount() int { + fake.listMutex.RLock() + defer fake.listMutex.RUnlock() + return len(fake.listArgsForCall) +} + +func (fake *FakeLogical) ListCalls(stub func(string) (*api.Secret, error)) { + fake.listMutex.Lock() + defer fake.listMutex.Unlock() + fake.ListStub = stub +} + +func (fake *FakeLogical) ListArgsForCall(i int) string { + fake.listMutex.RLock() + defer fake.listMutex.RUnlock() + argsForCall := fake.listArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLogical) ListReturns(result1 *api.Secret, result2 error) { + fake.listMutex.Lock() + defer fake.listMutex.Unlock() + fake.ListStub = nil + fake.listReturns = struct { + result1 *api.Secret + result2 error + }{result1, result2} +} + +func (fake *FakeLogical) ListReturnsOnCall(i int, result1 *api.Secret, result2 error) { + fake.listMutex.Lock() + defer fake.listMutex.Unlock() + fake.ListStub = nil + if fake.listReturnsOnCall == nil { + fake.listReturnsOnCall = make(map[int]struct { + result1 *api.Secret + result2 error + }) + } + fake.listReturnsOnCall[i] = struct { + result1 *api.Secret + result2 error + }{result1, result2} +} + +func (fake *FakeLogical) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.listMutex.RLock() + defer fake.listMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeLogical) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ vault.Logical = new(FakeLogical) diff --git a/vault/vaultfakes/fake_sys.go b/vault/vaultfakes/fake_sys.go new file mode 100644 index 0000000..51a4200 --- /dev/null +++ b/vault/vaultfakes/fake_sys.go @@ -0,0 +1,257 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package vaultfakes + +import ( + "sync" + + "github.com/dkyanakiev/vaulty/vault" + "github.com/hashicorp/vault/api" +) + +type FakeSys struct { + GetPolicyStub func(string) (string, error) + getPolicyMutex sync.RWMutex + getPolicyArgsForCall []struct { + arg1 string + } + getPolicyReturns struct { + result1 string + result2 error + } + getPolicyReturnsOnCall map[int]struct { + result1 string + result2 error + } + ListMountsStub func() (map[string]*api.MountOutput, error) + listMountsMutex sync.RWMutex + listMountsArgsForCall []struct { + } + listMountsReturns struct { + result1 map[string]*api.MountOutput + result2 error + } + listMountsReturnsOnCall map[int]struct { + result1 map[string]*api.MountOutput + result2 error + } + ListPoliciesStub func() ([]string, error) + listPoliciesMutex sync.RWMutex + listPoliciesArgsForCall []struct { + } + listPoliciesReturns struct { + result1 []string + result2 error + } + listPoliciesReturnsOnCall map[int]struct { + result1 []string + result2 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeSys) GetPolicy(arg1 string) (string, error) { + fake.getPolicyMutex.Lock() + ret, specificReturn := fake.getPolicyReturnsOnCall[len(fake.getPolicyArgsForCall)] + fake.getPolicyArgsForCall = append(fake.getPolicyArgsForCall, struct { + arg1 string + }{arg1}) + stub := fake.GetPolicyStub + fakeReturns := fake.getPolicyReturns + fake.recordInvocation("GetPolicy", []interface{}{arg1}) + fake.getPolicyMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeSys) GetPolicyCallCount() int { + fake.getPolicyMutex.RLock() + defer fake.getPolicyMutex.RUnlock() + return len(fake.getPolicyArgsForCall) +} + +func (fake *FakeSys) GetPolicyCalls(stub func(string) (string, error)) { + fake.getPolicyMutex.Lock() + defer fake.getPolicyMutex.Unlock() + fake.GetPolicyStub = stub +} + +func (fake *FakeSys) GetPolicyArgsForCall(i int) string { + fake.getPolicyMutex.RLock() + defer fake.getPolicyMutex.RUnlock() + argsForCall := fake.getPolicyArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeSys) GetPolicyReturns(result1 string, result2 error) { + fake.getPolicyMutex.Lock() + defer fake.getPolicyMutex.Unlock() + fake.GetPolicyStub = nil + fake.getPolicyReturns = struct { + result1 string + result2 error + }{result1, result2} +} + +func (fake *FakeSys) GetPolicyReturnsOnCall(i int, result1 string, result2 error) { + fake.getPolicyMutex.Lock() + defer fake.getPolicyMutex.Unlock() + fake.GetPolicyStub = nil + if fake.getPolicyReturnsOnCall == nil { + fake.getPolicyReturnsOnCall = make(map[int]struct { + result1 string + result2 error + }) + } + fake.getPolicyReturnsOnCall[i] = struct { + result1 string + result2 error + }{result1, result2} +} + +func (fake *FakeSys) ListMounts() (map[string]*api.MountOutput, error) { + fake.listMountsMutex.Lock() + ret, specificReturn := fake.listMountsReturnsOnCall[len(fake.listMountsArgsForCall)] + fake.listMountsArgsForCall = append(fake.listMountsArgsForCall, struct { + }{}) + stub := fake.ListMountsStub + fakeReturns := fake.listMountsReturns + fake.recordInvocation("ListMounts", []interface{}{}) + fake.listMountsMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeSys) ListMountsCallCount() int { + fake.listMountsMutex.RLock() + defer fake.listMountsMutex.RUnlock() + return len(fake.listMountsArgsForCall) +} + +func (fake *FakeSys) ListMountsCalls(stub func() (map[string]*api.MountOutput, error)) { + fake.listMountsMutex.Lock() + defer fake.listMountsMutex.Unlock() + fake.ListMountsStub = stub +} + +func (fake *FakeSys) ListMountsReturns(result1 map[string]*api.MountOutput, result2 error) { + fake.listMountsMutex.Lock() + defer fake.listMountsMutex.Unlock() + fake.ListMountsStub = nil + fake.listMountsReturns = struct { + result1 map[string]*api.MountOutput + result2 error + }{result1, result2} +} + +func (fake *FakeSys) ListMountsReturnsOnCall(i int, result1 map[string]*api.MountOutput, result2 error) { + fake.listMountsMutex.Lock() + defer fake.listMountsMutex.Unlock() + fake.ListMountsStub = nil + if fake.listMountsReturnsOnCall == nil { + fake.listMountsReturnsOnCall = make(map[int]struct { + result1 map[string]*api.MountOutput + result2 error + }) + } + fake.listMountsReturnsOnCall[i] = struct { + result1 map[string]*api.MountOutput + result2 error + }{result1, result2} +} + +func (fake *FakeSys) ListPolicies() ([]string, error) { + fake.listPoliciesMutex.Lock() + ret, specificReturn := fake.listPoliciesReturnsOnCall[len(fake.listPoliciesArgsForCall)] + fake.listPoliciesArgsForCall = append(fake.listPoliciesArgsForCall, struct { + }{}) + stub := fake.ListPoliciesStub + fakeReturns := fake.listPoliciesReturns + fake.recordInvocation("ListPolicies", []interface{}{}) + fake.listPoliciesMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeSys) ListPoliciesCallCount() int { + fake.listPoliciesMutex.RLock() + defer fake.listPoliciesMutex.RUnlock() + return len(fake.listPoliciesArgsForCall) +} + +func (fake *FakeSys) ListPoliciesCalls(stub func() ([]string, error)) { + fake.listPoliciesMutex.Lock() + defer fake.listPoliciesMutex.Unlock() + fake.ListPoliciesStub = stub +} + +func (fake *FakeSys) ListPoliciesReturns(result1 []string, result2 error) { + fake.listPoliciesMutex.Lock() + defer fake.listPoliciesMutex.Unlock() + fake.ListPoliciesStub = nil + fake.listPoliciesReturns = struct { + result1 []string + result2 error + }{result1, result2} +} + +func (fake *FakeSys) ListPoliciesReturnsOnCall(i int, result1 []string, result2 error) { + fake.listPoliciesMutex.Lock() + defer fake.listPoliciesMutex.Unlock() + fake.ListPoliciesStub = nil + if fake.listPoliciesReturnsOnCall == nil { + fake.listPoliciesReturnsOnCall = make(map[int]struct { + result1 []string + result2 error + }) + } + fake.listPoliciesReturnsOnCall[i] = struct { + result1 []string + result2 error + }{result1, result2} +} + +func (fake *FakeSys) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.getPolicyMutex.RLock() + defer fake.getPolicyMutex.RUnlock() + fake.listMountsMutex.RLock() + defer fake.listMountsMutex.RUnlock() + fake.listPoliciesMutex.RLock() + defer fake.listPoliciesMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeSys) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ vault.Sys = new(FakeSys) diff --git a/watcher/activity_test.go b/watcher/activity_test.go new file mode 100644 index 0000000..300d45a --- /dev/null +++ b/watcher/activity_test.go @@ -0,0 +1,34 @@ +package watcher_test + +import ( + "testing" + + "github.com/dkyanakiev/vaulty/watcher" + "github.com/stretchr/testify/require" +) + +func TestAdd(t *testing.T) { + r := require.New(t) + + activity := &watcher.ActivityPool{} + + activity.Add(make(chan struct{})) + r.Equal(len(activity.Activities), 1) + + activity.Add(make(chan struct{})) + r.Equal(len(activity.Activities), 2) +} + +func TestDeactivateAll(t *testing.T) { + r := require.New(t) + + activity := &watcher.ActivityPool{} + activity.Activities = []chan struct{}{ + make(chan struct{}, 1), + make(chan struct{}, 1), + } + + activity.DeactivateAll() + + r.Empty(activity.Activities) +} diff --git a/watcher/mounts_test.go b/watcher/mounts_test.go new file mode 100644 index 0000000..c4f9ece --- /dev/null +++ b/watcher/mounts_test.go @@ -0,0 +1,27 @@ +package watcher_test + +import ( + "testing" + "time" + + "github.com/dkyanakiev/vaulty/state" + "github.com/dkyanakiev/vaulty/watcher" + "github.com/dkyanakiev/vaulty/watcher/watcherfakes" + "github.com/stretchr/testify/assert" +) + +func TestSubscribeToMounts(t *testing.T) { + + fakeVault := &watcherfakes.FakeVault{} + state := state.New() + fakeWatcher := watcher.NewWatcher(state, fakeVault, 2*time.Second, nil) + + notifyCalled := false + notify := func() { + notifyCalled = true + } + + fakeWatcher.SubscribeToMounts(notify) + + assert.True(t, notifyCalled) +} diff --git a/watcher/watcher_test.go b/watcher/watcher_test.go new file mode 100644 index 0000000..4937cc2 --- /dev/null +++ b/watcher/watcher_test.go @@ -0,0 +1,83 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package watcher_test + +import ( + "io" + "testing" + "time" + + "github.com/dkyanakiev/vaulty/models" + "github.com/dkyanakiev/vaulty/state" + "github.com/dkyanakiev/vaulty/watcher" + "github.com/dkyanakiev/vaulty/watcher/watcherfakes" + "github.com/rs/zerolog" + + "github.com/stretchr/testify/require" +) + +func TestSubscription(t *testing.T) { + r := require.New(t) + logger := zerolog.New(io.Discard) + + vault := &watcherfakes.FakeVault{} + state := state.New() + + watcher := watcher.NewWatcher(state, vault, time.Second*2, &logger) + + var called bool + fn := func() { + called = true + } + + watcher.Subscribe(fn, "policy") + watcher.Notify("policy") + + r.True(called) + + called = false + watcher.Unsubscribe() + watcher.Notify("policy") + + r.False(called) +} + +func TestHandlerSubscription(t *testing.T) { + r := require.New(t) + logger := zerolog.New(io.Discard) + + vault := &watcherfakes.FakeVault{} + state := state.New() + + watcher := watcher.NewWatcher(state, vault, time.Second*2, &logger) + + var calledErrHandler bool + handleErr := func(_ string, _ ...interface{}) { + calledErrHandler = true + } + + var calledInfoHandler bool + handleInfo := func(_ string, _ ...interface{}) { + calledInfoHandler = true + } + + var calledFatalHandler bool + handleFatal := func(_ string, _ ...interface{}) { + calledFatalHandler = true + } + + watcher.SubscribeHandler(models.HandleError, handleErr) + watcher.SubscribeHandler(models.HandleInfo, handleInfo) + watcher.SubscribeHandler(models.HandleFatal, handleFatal) + + watcher.NotifyHandler(models.HandleError, "error") + watcher.NotifyHandler(models.HandleInfo, "info") + watcher.NotifyHandler(models.HandleFatal, "fatal") + + r.True(calledErrHandler) + r.True(calledInfoHandler) + r.True(calledFatalHandler) +} + +// TODO: Add more tests for the Watcher