diff --git a/CLAUDE.md b/CLAUDE.md index 956b8949..12936cb2 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -47,6 +47,7 @@ go test ./... - `internal/worktree/` - Git worktree operations - `internal/logging/` - Structured logging using slog - `internal/procmon/` - Database-backed process monitoring with heartbeats +- `internal/testutil/` - Shared test utilities and moq-generated mocks ## External Dependencies @@ -119,6 +120,117 @@ tail -f .co/debug.log cat .co/debug.log | jq . ``` +## Mock Generation + +The project uses [moq](https://github.com/matryer/moq) for generating test mocks. Mocks are stored in `internal/testutil/` and use the function-field pattern for easy customization per-test. + +### Installing moq + +moq is installed automatically via mise: + +```bash +mise install # Installs all tools including moq +``` + +The tool is defined in `mise.toml`: +```toml +"go:github.com/matryer/moq" = "latest" +``` + +### Regenerating Mocks + +After modifying interfaces or adding new `//go:generate` directives: + +```bash +mise run generate +``` + +This runs `go generate ./...` to regenerate all mocks. + +### Adding a New Mock + +1. Add a `//go:generate` directive to the interface file: + ```go + //go:generate moq -stub -out mypkg_mock.go . InterfaceName:InterfaceNameMock + ``` + +2. Run `mise run generate` to create the mock + +3. Use the mock in tests: + ```go + mock := &mypkg.InterfaceNameMock{ + MethodNameFunc: func(ctx context.Context, arg string) error { + return nil + }, + } + ``` + +### Available Mocks + +Mocks are generated in their respective package directories: +- `internal/git/git_mock.go` - Git CLI operations (`GitOperationsMock`) +- `internal/worktree/worktree_mock.go` - Git worktree operations (`WorktreeOperationsMock`) +- `internal/mise/mise_mock.go` - Mise tool operations (`MiseOperationsMock`) +- `internal/zellij/zellij_mock.go` - Zellij session management (`SessionManagerMock`, `SessionMock`) +- `internal/beads/beads_mock.go` - Beads CLI and reader interfaces (`BeadsCLIMock`, `BeadsReaderMock`) +- `internal/github/github_mock.go` - GitHub API client (`GitHubClientMock`) +- `internal/claude/claude_mock.go` - Claude runner (`ClaudeRunnerMock`) +- `internal/process/process_mock.go` - Process lister/killer (`ProcessListerMock`, `ProcessKillerMock`) +- `internal/task/task_mock.go` - Complexity estimator (`ComplexityEstimatorMock`) +- `internal/linear/linear_mock.go` - Linear API client (`LinearClientMock`) +- `internal/feedback/feedback_mock.go` - PR feedback processor (`FeedbackProcessorMock`) +- `internal/control/control_mock_test.go` - Orchestrator spawner, work destroyer (test-local mocks to avoid import cycle) + +### Testing Best Practices + +**Configuring mock behavior per-test:** +```go +mock := &git.GitOperationsMock{ + BranchExistsFunc: func(ctx context.Context, repoPath, branchName string) bool { + return branchName == "main" // Returns true only for "main" + }, +} +``` + +**Tracking and verifying calls:** +```go +mock := &git.GitOperationsMock{ + FetchPRRefFunc: func(ctx context.Context, repoPath string, prNumber int, localBranch string) error { + return nil + }, +} + +_ = mock.FetchPRRef(ctx, "/repo", 123, "pr-123") + +// Verify call count +calls := mock.FetchPRRefCalls() +if len(calls) != 1 { + t.Errorf("expected 1 call, got %d", len(calls)) +} + +// Verify call arguments +if calls[0].PrNumber != 123 { + t.Errorf("expected prNumber 123, got %d", calls[0].PrNumber) +} +``` + +**Nil functions return zero values:** +```go +mock := &git.GitOperationsMock{} // No functions set + +// Returns false (zero value for bool) when BranchExistsFunc is nil +mock.BranchExists(ctx, "/repo", "any") // returns false + +// Returns nil, nil when ListBranchesFunc is nil +branches, err := mock.ListBranches(ctx, "/repo") // branches=nil, err=nil +``` + +**Compile-time interface verification:** +```go +// Ensure mock implements the interface at compile time +var _ git.Operations = (*git.GitOperationsMock)(nil) +``` + ## Database Migrations The project uses a SQLite database (`tracking.db`) with schema migrations. diff --git a/internal/beads/beads_mock.go b/internal/beads/beads_mock.go new file mode 100644 index 00000000..9c7755a9 --- /dev/null +++ b/internal/beads/beads_mock.go @@ -0,0 +1,826 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package beads + +import ( + "context" + "sync" +) + +// Ensure, that BeadsCLIMock does implement CLI. +// If this is not the case, regenerate this file with moq. +var _ CLI = &BeadsCLIMock{} + +// BeadsCLIMock is a mock implementation of CLI. +// +// func TestSomethingThatUsesCLI(t *testing.T) { +// +// // make and configure a mocked CLI +// mockedCLI := &BeadsCLIMock{ +// AddCommentFunc: func(ctx context.Context, beadID string, comment string) error { +// panic("mock out the AddComment method") +// }, +// AddDependencyFunc: func(ctx context.Context, beadID string, dependsOnID string) error { +// panic("mock out the AddDependency method") +// }, +// AddLabelsFunc: func(ctx context.Context, beadID string, labels []string) error { +// panic("mock out the AddLabels method") +// }, +// CloseFunc: func(ctx context.Context, beadID string) error { +// panic("mock out the Close method") +// }, +// CreateFunc: func(ctx context.Context, opts CreateOptions) (string, error) { +// panic("mock out the Create method") +// }, +// ReopenFunc: func(ctx context.Context, beadID string) error { +// panic("mock out the Reopen method") +// }, +// SetExternalRefFunc: func(ctx context.Context, beadID string, externalRef string) error { +// panic("mock out the SetExternalRef method") +// }, +// UpdateFunc: func(ctx context.Context, beadID string, opts UpdateOptions) error { +// panic("mock out the Update method") +// }, +// } +// +// // use mockedCLI in code that requires CLI +// // and then make assertions. +// +// } +type BeadsCLIMock struct { + // AddCommentFunc mocks the AddComment method. + AddCommentFunc func(ctx context.Context, beadID string, comment string) error + + // AddDependencyFunc mocks the AddDependency method. + AddDependencyFunc func(ctx context.Context, beadID string, dependsOnID string) error + + // AddLabelsFunc mocks the AddLabels method. + AddLabelsFunc func(ctx context.Context, beadID string, labels []string) error + + // CloseFunc mocks the Close method. + CloseFunc func(ctx context.Context, beadID string) error + + // CreateFunc mocks the Create method. + CreateFunc func(ctx context.Context, opts CreateOptions) (string, error) + + // ReopenFunc mocks the Reopen method. + ReopenFunc func(ctx context.Context, beadID string) error + + // SetExternalRefFunc mocks the SetExternalRef method. + SetExternalRefFunc func(ctx context.Context, beadID string, externalRef string) error + + // UpdateFunc mocks the Update method. + UpdateFunc func(ctx context.Context, beadID string, opts UpdateOptions) error + + // calls tracks calls to the methods. + calls struct { + // AddComment holds details about calls to the AddComment method. + AddComment []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // BeadID is the beadID argument value. + BeadID string + // Comment is the comment argument value. + Comment string + } + // AddDependency holds details about calls to the AddDependency method. + AddDependency []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // BeadID is the beadID argument value. + BeadID string + // DependsOnID is the dependsOnID argument value. + DependsOnID string + } + // AddLabels holds details about calls to the AddLabels method. + AddLabels []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // BeadID is the beadID argument value. + BeadID string + // Labels is the labels argument value. + Labels []string + } + // Close holds details about calls to the Close method. + Close []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // BeadID is the beadID argument value. + BeadID string + } + // Create holds details about calls to the Create method. + Create []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Opts is the opts argument value. + Opts CreateOptions + } + // Reopen holds details about calls to the Reopen method. + Reopen []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // BeadID is the beadID argument value. + BeadID string + } + // SetExternalRef holds details about calls to the SetExternalRef method. + SetExternalRef []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // BeadID is the beadID argument value. + BeadID string + // ExternalRef is the externalRef argument value. + ExternalRef string + } + // Update holds details about calls to the Update method. + Update []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // BeadID is the beadID argument value. + BeadID string + // Opts is the opts argument value. + Opts UpdateOptions + } + } + lockAddComment sync.RWMutex + lockAddDependency sync.RWMutex + lockAddLabels sync.RWMutex + lockClose sync.RWMutex + lockCreate sync.RWMutex + lockReopen sync.RWMutex + lockSetExternalRef sync.RWMutex + lockUpdate sync.RWMutex +} + +// AddComment calls AddCommentFunc. +func (mock *BeadsCLIMock) AddComment(ctx context.Context, beadID string, comment string) error { + callInfo := struct { + Ctx context.Context + BeadID string + Comment string + }{ + Ctx: ctx, + BeadID: beadID, + Comment: comment, + } + mock.lockAddComment.Lock() + mock.calls.AddComment = append(mock.calls.AddComment, callInfo) + mock.lockAddComment.Unlock() + if mock.AddCommentFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.AddCommentFunc(ctx, beadID, comment) +} + +// AddCommentCalls gets all the calls that were made to AddComment. +// Check the length with: +// +// len(mockedCLI.AddCommentCalls()) +func (mock *BeadsCLIMock) AddCommentCalls() []struct { + Ctx context.Context + BeadID string + Comment string +} { + var calls []struct { + Ctx context.Context + BeadID string + Comment string + } + mock.lockAddComment.RLock() + calls = mock.calls.AddComment + mock.lockAddComment.RUnlock() + return calls +} + +// AddDependency calls AddDependencyFunc. +func (mock *BeadsCLIMock) AddDependency(ctx context.Context, beadID string, dependsOnID string) error { + callInfo := struct { + Ctx context.Context + BeadID string + DependsOnID string + }{ + Ctx: ctx, + BeadID: beadID, + DependsOnID: dependsOnID, + } + mock.lockAddDependency.Lock() + mock.calls.AddDependency = append(mock.calls.AddDependency, callInfo) + mock.lockAddDependency.Unlock() + if mock.AddDependencyFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.AddDependencyFunc(ctx, beadID, dependsOnID) +} + +// AddDependencyCalls gets all the calls that were made to AddDependency. +// Check the length with: +// +// len(mockedCLI.AddDependencyCalls()) +func (mock *BeadsCLIMock) AddDependencyCalls() []struct { + Ctx context.Context + BeadID string + DependsOnID string +} { + var calls []struct { + Ctx context.Context + BeadID string + DependsOnID string + } + mock.lockAddDependency.RLock() + calls = mock.calls.AddDependency + mock.lockAddDependency.RUnlock() + return calls +} + +// AddLabels calls AddLabelsFunc. +func (mock *BeadsCLIMock) AddLabels(ctx context.Context, beadID string, labels []string) error { + callInfo := struct { + Ctx context.Context + BeadID string + Labels []string + }{ + Ctx: ctx, + BeadID: beadID, + Labels: labels, + } + mock.lockAddLabels.Lock() + mock.calls.AddLabels = append(mock.calls.AddLabels, callInfo) + mock.lockAddLabels.Unlock() + if mock.AddLabelsFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.AddLabelsFunc(ctx, beadID, labels) +} + +// AddLabelsCalls gets all the calls that were made to AddLabels. +// Check the length with: +// +// len(mockedCLI.AddLabelsCalls()) +func (mock *BeadsCLIMock) AddLabelsCalls() []struct { + Ctx context.Context + BeadID string + Labels []string +} { + var calls []struct { + Ctx context.Context + BeadID string + Labels []string + } + mock.lockAddLabels.RLock() + calls = mock.calls.AddLabels + mock.lockAddLabels.RUnlock() + return calls +} + +// Close calls CloseFunc. +func (mock *BeadsCLIMock) Close(ctx context.Context, beadID string) error { + callInfo := struct { + Ctx context.Context + BeadID string + }{ + Ctx: ctx, + BeadID: beadID, + } + mock.lockClose.Lock() + mock.calls.Close = append(mock.calls.Close, callInfo) + mock.lockClose.Unlock() + if mock.CloseFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.CloseFunc(ctx, beadID) +} + +// CloseCalls gets all the calls that were made to Close. +// Check the length with: +// +// len(mockedCLI.CloseCalls()) +func (mock *BeadsCLIMock) CloseCalls() []struct { + Ctx context.Context + BeadID string +} { + var calls []struct { + Ctx context.Context + BeadID string + } + mock.lockClose.RLock() + calls = mock.calls.Close + mock.lockClose.RUnlock() + return calls +} + +// Create calls CreateFunc. +func (mock *BeadsCLIMock) Create(ctx context.Context, opts CreateOptions) (string, error) { + callInfo := struct { + Ctx context.Context + Opts CreateOptions + }{ + Ctx: ctx, + Opts: opts, + } + mock.lockCreate.Lock() + mock.calls.Create = append(mock.calls.Create, callInfo) + mock.lockCreate.Unlock() + if mock.CreateFunc == nil { + var ( + sOut string + errOut error + ) + return sOut, errOut + } + return mock.CreateFunc(ctx, opts) +} + +// CreateCalls gets all the calls that were made to Create. +// Check the length with: +// +// len(mockedCLI.CreateCalls()) +func (mock *BeadsCLIMock) CreateCalls() []struct { + Ctx context.Context + Opts CreateOptions +} { + var calls []struct { + Ctx context.Context + Opts CreateOptions + } + mock.lockCreate.RLock() + calls = mock.calls.Create + mock.lockCreate.RUnlock() + return calls +} + +// Reopen calls ReopenFunc. +func (mock *BeadsCLIMock) Reopen(ctx context.Context, beadID string) error { + callInfo := struct { + Ctx context.Context + BeadID string + }{ + Ctx: ctx, + BeadID: beadID, + } + mock.lockReopen.Lock() + mock.calls.Reopen = append(mock.calls.Reopen, callInfo) + mock.lockReopen.Unlock() + if mock.ReopenFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.ReopenFunc(ctx, beadID) +} + +// ReopenCalls gets all the calls that were made to Reopen. +// Check the length with: +// +// len(mockedCLI.ReopenCalls()) +func (mock *BeadsCLIMock) ReopenCalls() []struct { + Ctx context.Context + BeadID string +} { + var calls []struct { + Ctx context.Context + BeadID string + } + mock.lockReopen.RLock() + calls = mock.calls.Reopen + mock.lockReopen.RUnlock() + return calls +} + +// SetExternalRef calls SetExternalRefFunc. +func (mock *BeadsCLIMock) SetExternalRef(ctx context.Context, beadID string, externalRef string) error { + callInfo := struct { + Ctx context.Context + BeadID string + ExternalRef string + }{ + Ctx: ctx, + BeadID: beadID, + ExternalRef: externalRef, + } + mock.lockSetExternalRef.Lock() + mock.calls.SetExternalRef = append(mock.calls.SetExternalRef, callInfo) + mock.lockSetExternalRef.Unlock() + if mock.SetExternalRefFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.SetExternalRefFunc(ctx, beadID, externalRef) +} + +// SetExternalRefCalls gets all the calls that were made to SetExternalRef. +// Check the length with: +// +// len(mockedCLI.SetExternalRefCalls()) +func (mock *BeadsCLIMock) SetExternalRefCalls() []struct { + Ctx context.Context + BeadID string + ExternalRef string +} { + var calls []struct { + Ctx context.Context + BeadID string + ExternalRef string + } + mock.lockSetExternalRef.RLock() + calls = mock.calls.SetExternalRef + mock.lockSetExternalRef.RUnlock() + return calls +} + +// Update calls UpdateFunc. +func (mock *BeadsCLIMock) Update(ctx context.Context, beadID string, opts UpdateOptions) error { + callInfo := struct { + Ctx context.Context + BeadID string + Opts UpdateOptions + }{ + Ctx: ctx, + BeadID: beadID, + Opts: opts, + } + mock.lockUpdate.Lock() + mock.calls.Update = append(mock.calls.Update, callInfo) + mock.lockUpdate.Unlock() + if mock.UpdateFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.UpdateFunc(ctx, beadID, opts) +} + +// UpdateCalls gets all the calls that were made to Update. +// Check the length with: +// +// len(mockedCLI.UpdateCalls()) +func (mock *BeadsCLIMock) UpdateCalls() []struct { + Ctx context.Context + BeadID string + Opts UpdateOptions +} { + var calls []struct { + Ctx context.Context + BeadID string + Opts UpdateOptions + } + mock.lockUpdate.RLock() + calls = mock.calls.Update + mock.lockUpdate.RUnlock() + return calls +} + +// Ensure, that BeadsReaderMock does implement Reader. +// If this is not the case, regenerate this file with moq. +var _ Reader = &BeadsReaderMock{} + +// BeadsReaderMock is a mock implementation of Reader. +// +// func TestSomethingThatUsesReader(t *testing.T) { +// +// // make and configure a mocked Reader +// mockedReader := &BeadsReaderMock{ +// GetBeadFunc: func(ctx context.Context, id string) (*BeadWithDeps, error) { +// panic("mock out the GetBead method") +// }, +// GetBeadWithChildrenFunc: func(ctx context.Context, id string) ([]Bead, error) { +// panic("mock out the GetBeadWithChildren method") +// }, +// GetBeadsWithDepsFunc: func(ctx context.Context, beadIDs []string) (*BeadsWithDepsResult, error) { +// panic("mock out the GetBeadsWithDeps method") +// }, +// GetReadyBeadsFunc: func(ctx context.Context) ([]Bead, error) { +// panic("mock out the GetReadyBeads method") +// }, +// GetTransitiveDependenciesFunc: func(ctx context.Context, id string) ([]Bead, error) { +// panic("mock out the GetTransitiveDependencies method") +// }, +// ListBeadsFunc: func(ctx context.Context, status string) ([]Bead, error) { +// panic("mock out the ListBeads method") +// }, +// } +// +// // use mockedReader in code that requires Reader +// // and then make assertions. +// +// } +type BeadsReaderMock struct { + // GetBeadFunc mocks the GetBead method. + GetBeadFunc func(ctx context.Context, id string) (*BeadWithDeps, error) + + // GetBeadWithChildrenFunc mocks the GetBeadWithChildren method. + GetBeadWithChildrenFunc func(ctx context.Context, id string) ([]Bead, error) + + // GetBeadsWithDepsFunc mocks the GetBeadsWithDeps method. + GetBeadsWithDepsFunc func(ctx context.Context, beadIDs []string) (*BeadsWithDepsResult, error) + + // GetReadyBeadsFunc mocks the GetReadyBeads method. + GetReadyBeadsFunc func(ctx context.Context) ([]Bead, error) + + // GetTransitiveDependenciesFunc mocks the GetTransitiveDependencies method. + GetTransitiveDependenciesFunc func(ctx context.Context, id string) ([]Bead, error) + + // ListBeadsFunc mocks the ListBeads method. + ListBeadsFunc func(ctx context.Context, status string) ([]Bead, error) + + // calls tracks calls to the methods. + calls struct { + // GetBead holds details about calls to the GetBead method. + GetBead []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // ID is the id argument value. + ID string + } + // GetBeadWithChildren holds details about calls to the GetBeadWithChildren method. + GetBeadWithChildren []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // ID is the id argument value. + ID string + } + // GetBeadsWithDeps holds details about calls to the GetBeadsWithDeps method. + GetBeadsWithDeps []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // BeadIDs is the beadIDs argument value. + BeadIDs []string + } + // GetReadyBeads holds details about calls to the GetReadyBeads method. + GetReadyBeads []struct { + // Ctx is the ctx argument value. + Ctx context.Context + } + // GetTransitiveDependencies holds details about calls to the GetTransitiveDependencies method. + GetTransitiveDependencies []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // ID is the id argument value. + ID string + } + // ListBeads holds details about calls to the ListBeads method. + ListBeads []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Status is the status argument value. + Status string + } + } + lockGetBead sync.RWMutex + lockGetBeadWithChildren sync.RWMutex + lockGetBeadsWithDeps sync.RWMutex + lockGetReadyBeads sync.RWMutex + lockGetTransitiveDependencies sync.RWMutex + lockListBeads sync.RWMutex +} + +// GetBead calls GetBeadFunc. +func (mock *BeadsReaderMock) GetBead(ctx context.Context, id string) (*BeadWithDeps, error) { + callInfo := struct { + Ctx context.Context + ID string + }{ + Ctx: ctx, + ID: id, + } + mock.lockGetBead.Lock() + mock.calls.GetBead = append(mock.calls.GetBead, callInfo) + mock.lockGetBead.Unlock() + if mock.GetBeadFunc == nil { + var ( + beadWithDepsOut *BeadWithDeps + errOut error + ) + return beadWithDepsOut, errOut + } + return mock.GetBeadFunc(ctx, id) +} + +// GetBeadCalls gets all the calls that were made to GetBead. +// Check the length with: +// +// len(mockedReader.GetBeadCalls()) +func (mock *BeadsReaderMock) GetBeadCalls() []struct { + Ctx context.Context + ID string +} { + var calls []struct { + Ctx context.Context + ID string + } + mock.lockGetBead.RLock() + calls = mock.calls.GetBead + mock.lockGetBead.RUnlock() + return calls +} + +// GetBeadWithChildren calls GetBeadWithChildrenFunc. +func (mock *BeadsReaderMock) GetBeadWithChildren(ctx context.Context, id string) ([]Bead, error) { + callInfo := struct { + Ctx context.Context + ID string + }{ + Ctx: ctx, + ID: id, + } + mock.lockGetBeadWithChildren.Lock() + mock.calls.GetBeadWithChildren = append(mock.calls.GetBeadWithChildren, callInfo) + mock.lockGetBeadWithChildren.Unlock() + if mock.GetBeadWithChildrenFunc == nil { + var ( + beadsOut []Bead + errOut error + ) + return beadsOut, errOut + } + return mock.GetBeadWithChildrenFunc(ctx, id) +} + +// GetBeadWithChildrenCalls gets all the calls that were made to GetBeadWithChildren. +// Check the length with: +// +// len(mockedReader.GetBeadWithChildrenCalls()) +func (mock *BeadsReaderMock) GetBeadWithChildrenCalls() []struct { + Ctx context.Context + ID string +} { + var calls []struct { + Ctx context.Context + ID string + } + mock.lockGetBeadWithChildren.RLock() + calls = mock.calls.GetBeadWithChildren + mock.lockGetBeadWithChildren.RUnlock() + return calls +} + +// GetBeadsWithDeps calls GetBeadsWithDepsFunc. +func (mock *BeadsReaderMock) GetBeadsWithDeps(ctx context.Context, beadIDs []string) (*BeadsWithDepsResult, error) { + callInfo := struct { + Ctx context.Context + BeadIDs []string + }{ + Ctx: ctx, + BeadIDs: beadIDs, + } + mock.lockGetBeadsWithDeps.Lock() + mock.calls.GetBeadsWithDeps = append(mock.calls.GetBeadsWithDeps, callInfo) + mock.lockGetBeadsWithDeps.Unlock() + if mock.GetBeadsWithDepsFunc == nil { + var ( + beadsWithDepsResultOut *BeadsWithDepsResult + errOut error + ) + return beadsWithDepsResultOut, errOut + } + return mock.GetBeadsWithDepsFunc(ctx, beadIDs) +} + +// GetBeadsWithDepsCalls gets all the calls that were made to GetBeadsWithDeps. +// Check the length with: +// +// len(mockedReader.GetBeadsWithDepsCalls()) +func (mock *BeadsReaderMock) GetBeadsWithDepsCalls() []struct { + Ctx context.Context + BeadIDs []string +} { + var calls []struct { + Ctx context.Context + BeadIDs []string + } + mock.lockGetBeadsWithDeps.RLock() + calls = mock.calls.GetBeadsWithDeps + mock.lockGetBeadsWithDeps.RUnlock() + return calls +} + +// GetReadyBeads calls GetReadyBeadsFunc. +func (mock *BeadsReaderMock) GetReadyBeads(ctx context.Context) ([]Bead, error) { + callInfo := struct { + Ctx context.Context + }{ + Ctx: ctx, + } + mock.lockGetReadyBeads.Lock() + mock.calls.GetReadyBeads = append(mock.calls.GetReadyBeads, callInfo) + mock.lockGetReadyBeads.Unlock() + if mock.GetReadyBeadsFunc == nil { + var ( + beadsOut []Bead + errOut error + ) + return beadsOut, errOut + } + return mock.GetReadyBeadsFunc(ctx) +} + +// GetReadyBeadsCalls gets all the calls that were made to GetReadyBeads. +// Check the length with: +// +// len(mockedReader.GetReadyBeadsCalls()) +func (mock *BeadsReaderMock) GetReadyBeadsCalls() []struct { + Ctx context.Context +} { + var calls []struct { + Ctx context.Context + } + mock.lockGetReadyBeads.RLock() + calls = mock.calls.GetReadyBeads + mock.lockGetReadyBeads.RUnlock() + return calls +} + +// GetTransitiveDependencies calls GetTransitiveDependenciesFunc. +func (mock *BeadsReaderMock) GetTransitiveDependencies(ctx context.Context, id string) ([]Bead, error) { + callInfo := struct { + Ctx context.Context + ID string + }{ + Ctx: ctx, + ID: id, + } + mock.lockGetTransitiveDependencies.Lock() + mock.calls.GetTransitiveDependencies = append(mock.calls.GetTransitiveDependencies, callInfo) + mock.lockGetTransitiveDependencies.Unlock() + if mock.GetTransitiveDependenciesFunc == nil { + var ( + beadsOut []Bead + errOut error + ) + return beadsOut, errOut + } + return mock.GetTransitiveDependenciesFunc(ctx, id) +} + +// GetTransitiveDependenciesCalls gets all the calls that were made to GetTransitiveDependencies. +// Check the length with: +// +// len(mockedReader.GetTransitiveDependenciesCalls()) +func (mock *BeadsReaderMock) GetTransitiveDependenciesCalls() []struct { + Ctx context.Context + ID string +} { + var calls []struct { + Ctx context.Context + ID string + } + mock.lockGetTransitiveDependencies.RLock() + calls = mock.calls.GetTransitiveDependencies + mock.lockGetTransitiveDependencies.RUnlock() + return calls +} + +// ListBeads calls ListBeadsFunc. +func (mock *BeadsReaderMock) ListBeads(ctx context.Context, status string) ([]Bead, error) { + callInfo := struct { + Ctx context.Context + Status string + }{ + Ctx: ctx, + Status: status, + } + mock.lockListBeads.Lock() + mock.calls.ListBeads = append(mock.calls.ListBeads, callInfo) + mock.lockListBeads.Unlock() + if mock.ListBeadsFunc == nil { + var ( + beadsOut []Bead + errOut error + ) + return beadsOut, errOut + } + return mock.ListBeadsFunc(ctx, status) +} + +// ListBeadsCalls gets all the calls that were made to ListBeads. +// Check the length with: +// +// len(mockedReader.ListBeadsCalls()) +func (mock *BeadsReaderMock) ListBeadsCalls() []struct { + Ctx context.Context + Status string +} { + var calls []struct { + Ctx context.Context + Status string + } + mock.lockListBeads.RLock() + calls = mock.calls.ListBeads + mock.lockListBeads.RUnlock() + return calls +} diff --git a/internal/beads/cachemanager/mock_cache_manager.go b/internal/beads/cachemanager/mock_cache_manager.go new file mode 100644 index 00000000..6f447029 --- /dev/null +++ b/internal/beads/cachemanager/mock_cache_manager.go @@ -0,0 +1,64 @@ +package cachemanager + +import ( + "context" + "time" +) + +// CacheManagerMock is a mock implementation of CacheManager for testing. +// This uses the function-field pattern consistent with moq-generated mocks. +// Note: moq doesn't support generic interfaces, so this is hand-written. +type CacheManagerMock[K comparable, V any] struct { + GetFunc func(ctx context.Context, key K) (V, bool) + GetMultipleFunc func(ctx context.Context, keys []K) (map[K]V, bool) + GetWithRefreshFunc func(ctx context.Context, key K, ttl time.Duration) (V, bool) + SetFunc func(ctx context.Context, key K, value V, ttl time.Duration) + DeleteFunc func(ctx context.Context, keys ...K) error + FlushFunc func(ctx context.Context) error +} + +// Compile-time check that CacheManagerMock implements CacheManager. +var _ CacheManager[string, any] = (*CacheManagerMock[string, any])(nil) + +func (m *CacheManagerMock[K, V]) Get(ctx context.Context, key K) (V, bool) { + if m.GetFunc != nil { + return m.GetFunc(ctx, key) + } + var zero V + return zero, false +} + +func (m *CacheManagerMock[K, V]) GetMultiple(ctx context.Context, keys []K) (map[K]V, bool) { + if m.GetMultipleFunc != nil { + return m.GetMultipleFunc(ctx, keys) + } + return nil, false +} + +func (m *CacheManagerMock[K, V]) GetWithRefresh(ctx context.Context, key K, ttl time.Duration) (V, bool) { + if m.GetWithRefreshFunc != nil { + return m.GetWithRefreshFunc(ctx, key, ttl) + } + var zero V + return zero, false +} + +func (m *CacheManagerMock[K, V]) Set(ctx context.Context, key K, value V, ttl time.Duration) { + if m.SetFunc != nil { + m.SetFunc(ctx, key, value, ttl) + } +} + +func (m *CacheManagerMock[K, V]) Delete(ctx context.Context, keys ...K) error { + if m.DeleteFunc != nil { + return m.DeleteFunc(ctx, keys...) + } + return nil +} + +func (m *CacheManagerMock[K, V]) Flush(ctx context.Context) error { + if m.FlushFunc != nil { + return m.FlushFunc(ctx) + } + return nil +} diff --git a/internal/beads/cachemanager/mock_cache_manager_test.go b/internal/beads/cachemanager/mock_cache_manager_test.go deleted file mode 100644 index 6f0080f2..00000000 --- a/internal/beads/cachemanager/mock_cache_manager_test.go +++ /dev/null @@ -1,42 +0,0 @@ -package cachemanager - -import ( - "context" - "time" - - "github.com/stretchr/testify/mock" -) - -// MockCacheManager is a mock implementation of CacheManager for testing -type MockCacheManager[K comparable, V any] struct { - mock.Mock -} - -func (m *MockCacheManager[K, V]) Get(ctx context.Context, key K) (V, bool) { - args := m.Called(ctx, key) - return args.Get(0).(V), args.Bool(1) -} - -func (m *MockCacheManager[K, V]) GetMultiple(ctx context.Context, keys []K) (map[K]V, bool) { - args := m.Called(ctx, keys) - return args.Get(0).(map[K]V), args.Bool(1) -} - -func (m *MockCacheManager[K, V]) GetWithRefresh(ctx context.Context, key K, ttl time.Duration) (V, bool) { - args := m.Called(ctx, key, ttl) - return args.Get(0).(V), args.Bool(1) -} - -func (m *MockCacheManager[K, V]) Set(ctx context.Context, key K, value V, ttl time.Duration) { - m.Called(ctx, key, value, ttl) -} - -func (m *MockCacheManager[K, V]) Delete(ctx context.Context, keys ...K) error { - args := m.Called(ctx, keys) - return args.Error(0) -} - -func (m *MockCacheManager[K, V]) Flush(ctx context.Context) error { - args := m.Called(ctx) - return args.Error(0) -} diff --git a/internal/beads/cli.go b/internal/beads/cli.go index 481692ee..2d38069b 100644 --- a/internal/beads/cli.go +++ b/internal/beads/cli.go @@ -1,5 +1,7 @@ package beads +//go:generate moq -stub -out beads_mock.go . CLI:BeadsCLIMock Reader:BeadsReaderMock + import ( "context" ) diff --git a/internal/beads/client_test.go b/internal/beads/client_test.go new file mode 100644 index 00000000..778e14dc --- /dev/null +++ b/internal/beads/client_test.go @@ -0,0 +1,67 @@ +package beads + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// TestFlushCacheWithNilCache tests that FlushCache handles a nil cache gracefully. +func TestFlushCacheWithNilCache(t *testing.T) { + ctx := context.Background() + client := &Client{ + cache: nil, + cacheEnabled: false, + } + + // FlushCache with nil cache should not panic and return nil + err := client.FlushCache(ctx) + require.NoError(t, err) +} + +// TestBeadsWithDepsResult tests the result helper methods. +func TestBeadsWithDepsResult(t *testing.T) { + t.Run("GetBead returns BeadWithDeps for existing bead", func(t *testing.T) { + result := &BeadsWithDepsResult{ + Beads: map[string]Bead{ + "bead-1": {ID: "bead-1", Title: "Test Bead", Status: "open"}, + }, + Dependencies: map[string][]Dependency{ + "bead-1": {{IssueID: "bead-1", DependsOnID: "bead-2", Type: "blocks"}}, + }, + Dependents: map[string][]Dependent{ + "bead-1": {{IssueID: "bead-3", DependsOnID: "bead-1", Type: "blocked_by"}}, + }, + } + + beadWithDeps := result.GetBead("bead-1") + require.NotNil(t, beadWithDeps) + require.Equal(t, "bead-1", beadWithDeps.ID) + require.Equal(t, "Test Bead", beadWithDeps.Title) + require.Len(t, beadWithDeps.Dependencies, 1) + require.Len(t, beadWithDeps.Dependents, 1) + }) + + t.Run("GetBead returns nil for non-existing bead", func(t *testing.T) { + result := &BeadsWithDepsResult{ + Beads: map[string]Bead{}, + Dependencies: make(map[string][]Dependency), + Dependents: make(map[string][]Dependent), + } + + beadWithDeps := result.GetBead("nonexistent") + require.Nil(t, beadWithDeps, "expected nil for non-existing bead") + }) +} + +// TestDefaultClientConfig tests the default configuration. +func TestDefaultClientConfig(t *testing.T) { + cfg := DefaultClientConfig("/path/to/db") + + require.Equal(t, "/path/to/db", cfg.DBPath) + require.True(t, cfg.CacheEnabled, "expected CacheEnabled to be true by default") + require.Equal(t, 10*time.Minute, cfg.CacheExpiration) + require.Equal(t, 30*time.Minute, cfg.CacheCleanupTime) +} diff --git a/internal/claude/claude_mock.go b/internal/claude/claude_mock.go new file mode 100644 index 00000000..28b97a11 --- /dev/null +++ b/internal/claude/claude_mock.go @@ -0,0 +1,110 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package claude + +import ( + "context" + "github.com/newhook/co/internal/db" + "github.com/newhook/co/internal/project" + "sync" +) + +// Ensure, that ClaudeRunnerMock does implement Runner. +// If this is not the case, regenerate this file with moq. +var _ Runner = &ClaudeRunnerMock{} + +// ClaudeRunnerMock is a mock implementation of Runner. +// +// func TestSomethingThatUsesRunner(t *testing.T) { +// +// // make and configure a mocked Runner +// mockedRunner := &ClaudeRunnerMock{ +// RunFunc: func(ctx context.Context, database *db.DB, taskID string, prompt string, workDir string, cfg *project.Config) error { +// panic("mock out the Run method") +// }, +// } +// +// // use mockedRunner in code that requires Runner +// // and then make assertions. +// +// } +type ClaudeRunnerMock struct { + // RunFunc mocks the Run method. + RunFunc func(ctx context.Context, database *db.DB, taskID string, prompt string, workDir string, cfg *project.Config) error + + // calls tracks calls to the methods. + calls struct { + // Run holds details about calls to the Run method. + Run []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Database is the database argument value. + Database *db.DB + // TaskID is the taskID argument value. + TaskID string + // Prompt is the prompt argument value. + Prompt string + // WorkDir is the workDir argument value. + WorkDir string + // Cfg is the cfg argument value. + Cfg *project.Config + } + } + lockRun sync.RWMutex +} + +// Run calls RunFunc. +func (mock *ClaudeRunnerMock) Run(ctx context.Context, database *db.DB, taskID string, prompt string, workDir string, cfg *project.Config) error { + callInfo := struct { + Ctx context.Context + Database *db.DB + TaskID string + Prompt string + WorkDir string + Cfg *project.Config + }{ + Ctx: ctx, + Database: database, + TaskID: taskID, + Prompt: prompt, + WorkDir: workDir, + Cfg: cfg, + } + mock.lockRun.Lock() + mock.calls.Run = append(mock.calls.Run, callInfo) + mock.lockRun.Unlock() + if mock.RunFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.RunFunc(ctx, database, taskID, prompt, workDir, cfg) +} + +// RunCalls gets all the calls that were made to Run. +// Check the length with: +// +// len(mockedRunner.RunCalls()) +func (mock *ClaudeRunnerMock) RunCalls() []struct { + Ctx context.Context + Database *db.DB + TaskID string + Prompt string + WorkDir string + Cfg *project.Config +} { + var calls []struct { + Ctx context.Context + Database *db.DB + TaskID string + Prompt string + WorkDir string + Cfg *project.Config + } + mock.lockRun.RLock() + calls = mock.calls.Run + mock.lockRun.RUnlock() + return calls +} diff --git a/internal/claude/inline.go b/internal/claude/inline.go index c22d4e8f..752b64d9 100644 --- a/internal/claude/inline.go +++ b/internal/claude/inline.go @@ -1,5 +1,7 @@ package claude +//go:generate moq -stub -out claude_mock.go . Runner:ClaudeRunnerMock + import ( "context" "fmt" diff --git a/internal/claude/runner_test.go b/internal/claude/runner_test.go index 737ba51e..bbc48e0c 100644 --- a/internal/claude/runner_test.go +++ b/internal/claude/runner_test.go @@ -3,6 +3,8 @@ package claude import ( "strings" "testing" + + "github.com/stretchr/testify/require" ) func TestBuildLogAnalysisPrompt(t *testing.T) { @@ -100,15 +102,11 @@ func TestBuildLogAnalysisPrompt(t *testing.T) { result := BuildLogAnalysisPrompt(tt.params) for _, want := range tt.wantContains { - if !strings.Contains(result, want) { - t.Errorf("BuildLogAnalysisPrompt() missing expected content: %q\n\nGot:\n%s", want, result) - } + require.Contains(t, result, want, "BuildLogAnalysisPrompt() missing expected content") } for _, notWant := range tt.wantNotContain { - if strings.Contains(result, notWant) { - t.Errorf("BuildLogAnalysisPrompt() contains unexpected content: %q\n\nGot:\n%s", notWant, result) - } + require.NotContains(t, result, notWant, "BuildLogAnalysisPrompt() contains unexpected content") } }) } @@ -126,27 +124,13 @@ func TestLogAnalysisParams(t *testing.T) { LogContent: "content", } - if params.TaskID != "task-1" { - t.Errorf("TaskID = %s, want task-1", params.TaskID) - } - if params.WorkID != "work-1" { - t.Errorf("WorkID = %s, want work-1", params.WorkID) - } - if params.BranchName != "main" { - t.Errorf("BranchName = %s, want main", params.BranchName) - } - if params.RootIssueID != "issue-1" { - t.Errorf("RootIssueID = %s, want issue-1", params.RootIssueID) - } - if params.WorkflowName != "workflow-1" { - t.Errorf("WorkflowName = %s, want workflow-1", params.WorkflowName) - } - if params.JobName != "job-1" { - t.Errorf("JobName = %s, want job-1", params.JobName) - } - if params.LogContent != "content" { - t.Errorf("LogContent = %s, want content", params.LogContent) - } + require.Equal(t, "task-1", params.TaskID) + require.Equal(t, "work-1", params.WorkID) + require.Equal(t, "main", params.BranchName) + require.Equal(t, "issue-1", params.RootIssueID) + require.Equal(t, "workflow-1", params.WorkflowName) + require.Equal(t, "job-1", params.JobName) + require.Equal(t, "content", params.LogContent) } func TestBuildLogAnalysisPromptPriorityGuidelines(t *testing.T) { @@ -171,9 +155,7 @@ func TestBuildLogAnalysisPromptPriorityGuidelines(t *testing.T) { } for _, p := range priorities { - if !strings.Contains(result, p) { - t.Errorf("BuildLogAnalysisPrompt() missing priority guideline: %s", p) - } + require.True(t, strings.Contains(result, p), "BuildLogAnalysisPrompt() missing priority guideline: %s", p) } } @@ -191,17 +173,11 @@ func TestBuildLogAnalysisPromptBdCreateCommand(t *testing.T) { result := BuildLogAnalysisPrompt(params) // Check that bd create command format is included - if !strings.Contains(result, "bd create") { - t.Error("BuildLogAnalysisPrompt() missing bd create command") - } + require.Contains(t, result, "bd create", "BuildLogAnalysisPrompt() missing bd create command") // Check that it includes type options - if !strings.Contains(result, "--type") { - t.Error("BuildLogAnalysisPrompt() missing --type flag") - } + require.Contains(t, result, "--type", "BuildLogAnalysisPrompt() missing --type flag") // Check that it includes priority option - if !strings.Contains(result, "--priority") { - t.Error("BuildLogAnalysisPrompt() missing --priority flag") - } + require.Contains(t, result, "--priority", "BuildLogAnalysisPrompt() missing --priority flag") } diff --git a/internal/control/control_mock_test.go b/internal/control/control_mock_test.go new file mode 100644 index 00000000..eebf1638 --- /dev/null +++ b/internal/control/control_mock_test.go @@ -0,0 +1,198 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package control_test + +import ( + "context" + "github.com/newhook/co/internal/control" + "github.com/newhook/co/internal/project" + "io" + "sync" +) + +// Ensure, that OrchestratorSpawnerMock does implement control.OrchestratorSpawner. +// If this is not the case, regenerate this file with moq. +var _ control.OrchestratorSpawner = &OrchestratorSpawnerMock{} + +// OrchestratorSpawnerMock is a mock implementation of control.OrchestratorSpawner. +// +// func TestSomethingThatUsesOrchestratorSpawner(t *testing.T) { +// +// // make and configure a mocked control.OrchestratorSpawner +// mockedOrchestratorSpawner := &OrchestratorSpawnerMock{ +// SpawnWorkOrchestratorFunc: func(ctx context.Context, workID string, projectName string, workDir string, friendlyName string, w io.Writer) error { +// panic("mock out the SpawnWorkOrchestrator method") +// }, +// } +// +// // use mockedOrchestratorSpawner in code that requires control.OrchestratorSpawner +// // and then make assertions. +// +// } +type OrchestratorSpawnerMock struct { + // SpawnWorkOrchestratorFunc mocks the SpawnWorkOrchestrator method. + SpawnWorkOrchestratorFunc func(ctx context.Context, workID string, projectName string, workDir string, friendlyName string, w io.Writer) error + + // calls tracks calls to the methods. + calls struct { + // SpawnWorkOrchestrator holds details about calls to the SpawnWorkOrchestrator method. + SpawnWorkOrchestrator []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // WorkID is the workID argument value. + WorkID string + // ProjectName is the projectName argument value. + ProjectName string + // WorkDir is the workDir argument value. + WorkDir string + // FriendlyName is the friendlyName argument value. + FriendlyName string + // W is the w argument value. + W io.Writer + } + } + lockSpawnWorkOrchestrator sync.RWMutex +} + +// SpawnWorkOrchestrator calls SpawnWorkOrchestratorFunc. +func (mock *OrchestratorSpawnerMock) SpawnWorkOrchestrator(ctx context.Context, workID string, projectName string, workDir string, friendlyName string, w io.Writer) error { + callInfo := struct { + Ctx context.Context + WorkID string + ProjectName string + WorkDir string + FriendlyName string + W io.Writer + }{ + Ctx: ctx, + WorkID: workID, + ProjectName: projectName, + WorkDir: workDir, + FriendlyName: friendlyName, + W: w, + } + mock.lockSpawnWorkOrchestrator.Lock() + mock.calls.SpawnWorkOrchestrator = append(mock.calls.SpawnWorkOrchestrator, callInfo) + mock.lockSpawnWorkOrchestrator.Unlock() + if mock.SpawnWorkOrchestratorFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.SpawnWorkOrchestratorFunc(ctx, workID, projectName, workDir, friendlyName, w) +} + +// SpawnWorkOrchestratorCalls gets all the calls that were made to SpawnWorkOrchestrator. +// Check the length with: +// +// len(mockedOrchestratorSpawner.SpawnWorkOrchestratorCalls()) +func (mock *OrchestratorSpawnerMock) SpawnWorkOrchestratorCalls() []struct { + Ctx context.Context + WorkID string + ProjectName string + WorkDir string + FriendlyName string + W io.Writer +} { + var calls []struct { + Ctx context.Context + WorkID string + ProjectName string + WorkDir string + FriendlyName string + W io.Writer + } + mock.lockSpawnWorkOrchestrator.RLock() + calls = mock.calls.SpawnWorkOrchestrator + mock.lockSpawnWorkOrchestrator.RUnlock() + return calls +} + +// Ensure, that WorkDestroyerMock does implement control.WorkDestroyer. +// If this is not the case, regenerate this file with moq. +var _ control.WorkDestroyer = &WorkDestroyerMock{} + +// WorkDestroyerMock is a mock implementation of control.WorkDestroyer. +// +// func TestSomethingThatUsesWorkDestroyer(t *testing.T) { +// +// // make and configure a mocked control.WorkDestroyer +// mockedWorkDestroyer := &WorkDestroyerMock{ +// DestroyWorkFunc: func(ctx context.Context, proj *project.Project, workID string, w io.Writer) error { +// panic("mock out the DestroyWork method") +// }, +// } +// +// // use mockedWorkDestroyer in code that requires control.WorkDestroyer +// // and then make assertions. +// +// } +type WorkDestroyerMock struct { + // DestroyWorkFunc mocks the DestroyWork method. + DestroyWorkFunc func(ctx context.Context, proj *project.Project, workID string, w io.Writer) error + + // calls tracks calls to the methods. + calls struct { + // DestroyWork holds details about calls to the DestroyWork method. + DestroyWork []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Proj is the proj argument value. + Proj *project.Project + // WorkID is the workID argument value. + WorkID string + // W is the w argument value. + W io.Writer + } + } + lockDestroyWork sync.RWMutex +} + +// DestroyWork calls DestroyWorkFunc. +func (mock *WorkDestroyerMock) DestroyWork(ctx context.Context, proj *project.Project, workID string, w io.Writer) error { + callInfo := struct { + Ctx context.Context + Proj *project.Project + WorkID string + W io.Writer + }{ + Ctx: ctx, + Proj: proj, + WorkID: workID, + W: w, + } + mock.lockDestroyWork.Lock() + mock.calls.DestroyWork = append(mock.calls.DestroyWork, callInfo) + mock.lockDestroyWork.Unlock() + if mock.DestroyWorkFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.DestroyWorkFunc(ctx, proj, workID, w) +} + +// DestroyWorkCalls gets all the calls that were made to DestroyWork. +// Check the length with: +// +// len(mockedWorkDestroyer.DestroyWorkCalls()) +func (mock *WorkDestroyerMock) DestroyWorkCalls() []struct { + Ctx context.Context + Proj *project.Project + WorkID string + W io.Writer +} { + var calls []struct { + Ctx context.Context + Proj *project.Project + WorkID string + W io.Writer + } + mock.lockDestroyWork.RLock() + calls = mock.calls.DestroyWork + mock.lockDestroyWork.RUnlock() + return calls +} diff --git a/internal/control/control_test.go b/internal/control/control_test.go new file mode 100644 index 00000000..97397029 --- /dev/null +++ b/internal/control/control_test.go @@ -0,0 +1,810 @@ +package control_test + +import ( + "context" + "errors" + "io" + "testing" + "time" + + "github.com/newhook/co/internal/control" + "github.com/newhook/co/internal/db" + "github.com/newhook/co/internal/feedback" + "github.com/newhook/co/internal/git" + "github.com/newhook/co/internal/mise" + "github.com/newhook/co/internal/project" + "github.com/newhook/co/internal/worktree" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// setupTestProject creates a minimal test project with an in-memory database. +func setupTestProject(t *testing.T) (*project.Project, func()) { + t.Helper() + ctx := context.Background() + + database, err := db.OpenPath(ctx, ":memory:") + require.NoError(t, err, "failed to open database") + + cfg := &project.Config{ + Project: project.ProjectConfig{ + Name: "test-project", + CreatedAt: time.Now(), + }, + Repo: project.RepoConfig{ + BaseBranch: "main", + }, + // SchedulerConfig will use defaults + } + + proj := &project.Project{ + Root: "/tmp/test-project", + Config: cfg, + DB: database, + } + + cleanup := func() { + database.Close() + } + + return proj, cleanup +} + +// testMocks holds all mocked dependencies for ControlPlane tests. +type testMocks struct { + CP *control.ControlPlane + Git *git.GitOperationsMock + Worktree *worktree.WorktreeOperationsMock + Feedback *feedback.FeedbackProcessorMock + Spawner *OrchestratorSpawnerMock + Destroyer *WorkDestroyerMock +} + +// setupControlPlane creates a ControlPlane with all mocked dependencies. +func setupControlPlane() *testMocks { + gitMock := &git.GitOperationsMock{} + wtMock := &worktree.WorktreeOperationsMock{} + miseMock := &mise.MiseOperationsMock{} + feedbackMock := &feedback.FeedbackProcessorMock{} + spawnerMock := &OrchestratorSpawnerMock{} + destroyerMock := &WorkDestroyerMock{} + + cp := control.NewControlPlaneWithDeps( + gitMock, + wtMock, + nil, // zellij not used in these tests + func(dir string) mise.Operations { return miseMock }, + feedbackMock, + spawnerMock, + destroyerMock, + ) + + return &testMocks{ + CP: cp, + Git: gitMock, + Worktree: wtMock, + Feedback: feedbackMock, + Spawner: spawnerMock, + Destroyer: destroyerMock, + } +} + +// createTestWork creates a work record for testing with minimal required fields. +func createTestWork(ctx context.Context, t *testing.T, database *db.DB, workID, branchName, rootIssueID string) { + t.Helper() + err := database.CreateWork(ctx, workID, workID, "", branchName, "main", rootIssueID, false) + require.NoError(t, err) +} + +func TestHandleGitPushTask(t *testing.T) { + ctx := context.Background() + proj, cleanup := setupTestProject(t) + defer cleanup() + + t.Run("succeeds with metadata", func(t *testing.T) { + mocks := setupControlPlane() + + // Configure git mock to succeed + mocks.Git.PushSetUpstreamFunc = func(ctx context.Context, branch, dir string) error { + return nil + } + + // Create work for the task + createTestWork(ctx, t, proj.DB, "w-test", "feature-branch", "root-issue-1") + defer proj.DB.DeleteWork(ctx, "w-test") + + task := &db.ScheduledTask{ + ID: "task-1", + WorkID: "w-test", + TaskType: db.TaskTypeGitPush, + Metadata: map[string]string{ + "branch": "feature-branch", + "dir": "/work/tree", + }, + } + + err := mocks.CP.HandleGitPushTask(ctx, proj, task) + require.NoError(t, err) + + // Verify git push was called + calls := mocks.Git.PushSetUpstreamCalls() + require.Len(t, calls, 1) + assert.Equal(t, "feature-branch", calls[0].Branch) + assert.Equal(t, "/work/tree", calls[0].Dir) + }) + + t.Run("uses work info when metadata empty", func(t *testing.T) { + mocks := setupControlPlane() + + // Configure git mock + mocks.Git.PushSetUpstreamFunc = func(ctx context.Context, branch, dir string) error { + return nil + } + + // Create work with worktree path + createTestWork(ctx, t, proj.DB, "w-test2", "from-work-branch", "root-issue-1") + err := proj.DB.UpdateWorkWorktreePath(ctx, "w-test2", "/from/work/path") + require.NoError(t, err) + defer proj.DB.DeleteWork(ctx, "w-test2") + + task := &db.ScheduledTask{ + ID: "task-2", + WorkID: "w-test2", + TaskType: db.TaskTypeGitPush, + Metadata: map[string]string{}, // Empty metadata + } + + err = mocks.CP.HandleGitPushTask(ctx, proj, task) + require.NoError(t, err) + + // Verify it used work's branch and path + calls := mocks.Git.PushSetUpstreamCalls() + require.Len(t, calls, 1) + assert.Equal(t, "from-work-branch", calls[0].Branch) + assert.Equal(t, "/from/work/path", calls[0].Dir) + }) + + t.Run("returns error when git push fails", func(t *testing.T) { + mocks := setupControlPlane() + + mocks.Git.PushSetUpstreamFunc = func(ctx context.Context, branch, dir string) error { + return errors.New("push failed: authentication error") + } + + task := &db.ScheduledTask{ + ID: "task-3", + WorkID: "w-test", + TaskType: db.TaskTypeGitPush, + Metadata: map[string]string{ + "branch": "branch", + "dir": "/dir", + }, + } + + err := mocks.CP.HandleGitPushTask(ctx, proj, task) + require.Error(t, err) + assert.Contains(t, err.Error(), "push failed") + }) + + t.Run("returns error when no branch or dir", func(t *testing.T) { + mocks := setupControlPlane() + + task := &db.ScheduledTask{ + ID: "task-4", + WorkID: "nonexistent-work", + TaskType: db.TaskTypeGitPush, + Metadata: map[string]string{}, + } + + err := mocks.CP.HandleGitPushTask(ctx, proj, task) + require.Error(t, err) + assert.Contains(t, err.Error(), "work not found") + }) +} + +func TestHandleSpawnOrchestratorTask(t *testing.T) { + ctx := context.Background() + proj, cleanup := setupTestProject(t) + defer cleanup() + + t.Run("succeeds when work exists", func(t *testing.T) { + mocks := setupControlPlane() + + mocks.Spawner.SpawnWorkOrchestratorFunc = func(ctx context.Context, workID, projectName, workDir, friendlyName string, w io.Writer) error { + return nil + } + + // Create work + createTestWork(ctx, t, proj.DB, "w-spawn", "spawn-branch", "root-1") + err := proj.DB.UpdateWorkWorktreePath(ctx, "w-spawn", "/spawn/tree") + require.NoError(t, err) + defer proj.DB.DeleteWork(ctx, "w-spawn") + + task := &db.ScheduledTask{ + ID: "spawn-task-1", + WorkID: "w-spawn", + TaskType: db.TaskTypeSpawnOrchestrator, + Metadata: map[string]string{ + "worker_name": "test-worker", + }, + } + + err = mocks.CP.HandleSpawnOrchestratorTask(ctx, proj, task) + require.NoError(t, err) + + // Verify spawner was called with correct args + calls := mocks.Spawner.SpawnWorkOrchestratorCalls() + require.Len(t, calls, 1) + assert.Equal(t, "w-spawn", calls[0].WorkID) + assert.Equal(t, "test-project", calls[0].ProjectName) + assert.Equal(t, "/spawn/tree", calls[0].WorkDir) + assert.Equal(t, "test-worker", calls[0].FriendlyName) + }) + + t.Run("succeeds when work deleted", func(t *testing.T) { + mocks := setupControlPlane() + + // Work doesn't exist - task should complete without error + task := &db.ScheduledTask{ + ID: "spawn-task-2", + WorkID: "nonexistent", + TaskType: db.TaskTypeSpawnOrchestrator, + Metadata: map[string]string{}, + } + + err := mocks.CP.HandleSpawnOrchestratorTask(ctx, proj, task) + require.NoError(t, err) + + // Spawner should not have been called + assert.Len(t, mocks.Spawner.SpawnWorkOrchestratorCalls(), 0) + }) + + t.Run("returns error when no worktree path", func(t *testing.T) { + mocks := setupControlPlane() + + // Create work without worktree path + createTestWork(ctx, t, proj.DB, "w-no-tree", "branch", "root-1") + defer proj.DB.DeleteWork(ctx, "w-no-tree") + + task := &db.ScheduledTask{ + ID: "spawn-task-3", + WorkID: "w-no-tree", + TaskType: db.TaskTypeSpawnOrchestrator, + Metadata: map[string]string{}, + } + + err := mocks.CP.HandleSpawnOrchestratorTask(ctx, proj, task) + require.Error(t, err) + assert.Contains(t, err.Error(), "no worktree path") + }) + + t.Run("returns error when spawner fails", func(t *testing.T) { + mocks := setupControlPlane() + + mocks.Spawner.SpawnWorkOrchestratorFunc = func(ctx context.Context, workID, projectName, workDir, friendlyName string, w io.Writer) error { + return errors.New("zellij error") + } + + createTestWork(ctx, t, proj.DB, "w-fail", "branch", "root-1") + err := proj.DB.UpdateWorkWorktreePath(ctx, "w-fail", "/fail/tree") + require.NoError(t, err) + defer proj.DB.DeleteWork(ctx, "w-fail") + + task := &db.ScheduledTask{ + ID: "spawn-task-4", + WorkID: "w-fail", + TaskType: db.TaskTypeSpawnOrchestrator, + Metadata: map[string]string{}, + } + + err = mocks.CP.HandleSpawnOrchestratorTask(ctx, proj, task) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to spawn orchestrator") + }) +} + +func TestHandleDestroyWorktreeTask(t *testing.T) { + ctx := context.Background() + proj, cleanup := setupTestProject(t) + defer cleanup() + + t.Run("succeeds when work exists", func(t *testing.T) { + mocks := setupControlPlane() + + mocks.Destroyer.DestroyWorkFunc = func(ctx context.Context, proj *project.Project, workID string, w io.Writer) error { + return nil + } + + // Create work + createTestWork(ctx, t, proj.DB, "w-destroy", "destroy-branch", "root-1") + // Note: Normally work would be deleted by the handler, but we use a mock + defer proj.DB.DeleteWork(ctx, "w-destroy") + + task := &db.ScheduledTask{ + ID: "destroy-task-1", + WorkID: "w-destroy", + TaskType: db.TaskTypeDestroyWorktree, + } + + err := mocks.CP.HandleDestroyWorktreeTask(ctx, proj, task) + require.NoError(t, err) + + // Verify destroyer was called + calls := mocks.Destroyer.DestroyWorkCalls() + require.Len(t, calls, 1) + assert.Equal(t, "w-destroy", calls[0].WorkID) + }) + + t.Run("succeeds when work already deleted", func(t *testing.T) { + mocks := setupControlPlane() + + // Work doesn't exist - task should complete without error + task := &db.ScheduledTask{ + ID: "destroy-task-2", + WorkID: "nonexistent", + TaskType: db.TaskTypeDestroyWorktree, + } + + err := mocks.CP.HandleDestroyWorktreeTask(ctx, proj, task) + require.NoError(t, err) + + // Destroyer should not have been called + assert.Len(t, mocks.Destroyer.DestroyWorkCalls(), 0) + }) + + t.Run("returns error when destroyer fails", func(t *testing.T) { + mocks := setupControlPlane() + + mocks.Destroyer.DestroyWorkFunc = func(ctx context.Context, proj *project.Project, workID string, w io.Writer) error { + return errors.New("filesystem error") + } + + createTestWork(ctx, t, proj.DB, "w-fail-destroy", "branch", "root-1") + defer proj.DB.DeleteWork(ctx, "w-fail-destroy") + + task := &db.ScheduledTask{ + ID: "destroy-task-3", + WorkID: "w-fail-destroy", + TaskType: db.TaskTypeDestroyWorktree, + } + + err := mocks.CP.HandleDestroyWorktreeTask(ctx, proj, task) + require.Error(t, err) + assert.Contains(t, err.Error(), "filesystem error") + }) +} + +func TestHandlePRFeedbackTask(t *testing.T) { + ctx := context.Background() + proj, cleanup := setupTestProject(t) + defer cleanup() + + t.Run("processes feedback when PR exists", func(t *testing.T) { + mocks := setupControlPlane() + + mocks.Feedback.ProcessPRFeedbackFunc = func(ctx context.Context, proj *project.Project, database *db.DB, workID string) (int, error) { + return 3, nil // Created 3 beads + } + + // Create work with PR URL + createTestWork(ctx, t, proj.DB, "w-feedback", "feedback-branch", "root-1") + err := proj.DB.SetWorkPRURLAndScheduleFeedback(ctx, "w-feedback", "https://github.com/org/repo/pull/123", 5*time.Minute, 5*time.Minute) + require.NoError(t, err) + defer proj.DB.DeleteWork(ctx, "w-feedback") + + task := &db.ScheduledTask{ + ID: "feedback-task-1", + WorkID: "w-feedback", + TaskType: db.TaskTypePRFeedback, + } + + err = mocks.CP.HandlePRFeedbackTask(ctx, proj, task) + require.NoError(t, err) + + // Verify feedback processor was called + calls := mocks.Feedback.ProcessPRFeedbackCalls() + require.Len(t, calls, 1) + assert.Equal(t, "w-feedback", calls[0].WorkID) + }) + + t.Run("skips processing when no PR URL", func(t *testing.T) { + mocks := setupControlPlane() + + // Create work without PR URL + createTestWork(ctx, t, proj.DB, "w-no-pr", "no-pr-branch", "root-1") + defer proj.DB.DeleteWork(ctx, "w-no-pr") + + task := &db.ScheduledTask{ + ID: "feedback-task-2", + WorkID: "w-no-pr", + TaskType: db.TaskTypePRFeedback, + } + + err := mocks.CP.HandlePRFeedbackTask(ctx, proj, task) + require.NoError(t, err) + + // Feedback processor should not have been called + assert.Len(t, mocks.Feedback.ProcessPRFeedbackCalls(), 0) + }) + + t.Run("returns error when feedback processing fails", func(t *testing.T) { + mocks := setupControlPlane() + + mocks.Feedback.ProcessPRFeedbackFunc = func(ctx context.Context, proj *project.Project, database *db.DB, workID string) (int, error) { + return 0, errors.New("GitHub API error") + } + + createTestWork(ctx, t, proj.DB, "w-fail-fb", "branch", "root-1") + err := proj.DB.SetWorkPRURLAndScheduleFeedback(ctx, "w-fail-fb", "https://github.com/org/repo/pull/456", 5*time.Minute, 5*time.Minute) + require.NoError(t, err) + defer proj.DB.DeleteWork(ctx, "w-fail-fb") + + task := &db.ScheduledTask{ + ID: "feedback-task-3", + WorkID: "w-fail-fb", + TaskType: db.TaskTypePRFeedback, + } + + err = mocks.CP.HandlePRFeedbackTask(ctx, proj, task) + require.Error(t, err) + assert.Contains(t, err.Error(), "GitHub API error") + }) +} + +func TestGetTaskHandlers(t *testing.T) { + mocks := setupControlPlane() + + handlers := mocks.CP.GetTaskHandlers() + + // Verify all expected task types have handlers + expectedTypes := []string{ + db.TaskTypeCreateWorktree, + db.TaskTypeSpawnOrchestrator, + db.TaskTypePRFeedback, + db.TaskTypeGitPush, + db.TaskTypeDestroyWorktree, + db.TaskTypeImportPR, + db.TaskTypeCommentResolution, + db.TaskTypeGitHubComment, + db.TaskTypeGitHubResolveThread, + } + + for _, taskType := range expectedTypes { + _, ok := handlers[taskType] + assert.True(t, ok, "expected handler for task type %s", taskType) + } +} + +func TestNewControlPlane(t *testing.T) { + cp := control.NewControlPlane() + require.NotNil(t, cp) + + // Verify default dependencies are set + assert.NotNil(t, cp.Git) + assert.NotNil(t, cp.Worktree) + assert.NotNil(t, cp.Zellij) + assert.NotNil(t, cp.Mise) + assert.NotNil(t, cp.FeedbackProcessor) + assert.NotNil(t, cp.OrchestratorSpawner) + assert.NotNil(t, cp.WorkDestroyer) +} + +func TestDefaultOrchestratorSpawner(t *testing.T) { + // Compile-time check that DefaultOrchestratorSpawner implements OrchestratorSpawner + var _ control.OrchestratorSpawner = (*control.DefaultOrchestratorSpawner)(nil) +} + +func TestDefaultWorkDestroyer(t *testing.T) { + // Compile-time check that DefaultWorkDestroyer implements WorkDestroyer + var _ control.WorkDestroyer = (*control.DefaultWorkDestroyer)(nil) +} + +func TestHandleCreateWorktreeTask(t *testing.T) { + ctx := context.Background() + proj, cleanup := setupTestProject(t) + defer cleanup() + + t.Run("succeeds when work is deleted", func(t *testing.T) { + mocks := setupControlPlane() + + // Work doesn't exist - should complete without error + task := &db.ScheduledTask{ + ID: "create-task-3", + WorkID: "nonexistent", + TaskType: db.TaskTypeCreateWorktree, + Metadata: map[string]string{ + "branch": "some-branch", + }, + } + + err := mocks.CP.HandleCreateWorktreeTask(ctx, proj, task) + require.NoError(t, err) + }) + + t.Run("skips worktree creation when already exists", func(t *testing.T) { + mocks := setupControlPlane() + + mocks.Git.PushSetUpstreamFunc = func(ctx context.Context, branch, dir string) error { + return nil + } + + // Create work with existing worktree path + createTestWork(ctx, t, proj.DB, "w-exists", "exists-branch", "root-1") + err := proj.DB.UpdateWorkWorktreePath(ctx, "w-exists", "/existing/path") + require.NoError(t, err) + defer proj.DB.DeleteWork(ctx, "w-exists") + + task := &db.ScheduledTask{ + ID: "create-task-6", + WorkID: "w-exists", + TaskType: db.TaskTypeCreateWorktree, + Metadata: map[string]string{ + "branch": "exists-branch", + }, + } + + err = mocks.CP.HandleCreateWorktreeTask(ctx, proj, task) + require.NoError(t, err) + + // Worktree creation should not have been called + assert.Len(t, mocks.Worktree.CreateCalls(), 0) + assert.Len(t, mocks.Worktree.CreateFromExistingCalls(), 0) + }) + + t.Run("uses default base branch from config", func(t *testing.T) { + mocks := setupControlPlane() + + mocks.Git.PushSetUpstreamFunc = func(ctx context.Context, branch, dir string) error { + return nil + } + + // Create work with existing worktree path - allows us to test config lookup without filesystem ops + createTestWork(ctx, t, proj.DB, "w-default-base", "default-branch", "root-1") + err := proj.DB.UpdateWorkWorktreePath(ctx, "w-default-base", "/work/path") + require.NoError(t, err) + defer proj.DB.DeleteWork(ctx, "w-default-base") + + task := &db.ScheduledTask{ + ID: "create-task-7", + WorkID: "w-default-base", + TaskType: db.TaskTypeCreateWorktree, + Metadata: map[string]string{ + "branch": "default-branch", + // No base_branch in metadata - should use config default + }, + } + + err = mocks.CP.HandleCreateWorktreeTask(ctx, proj, task) + require.NoError(t, err) + + // Should not try to create worktree since it already exists + assert.Len(t, mocks.Worktree.CreateCalls(), 0) + }) +} + +func TestScheduleDestroyWorktree(t *testing.T) { + ctx := context.Background() + proj, cleanup := setupTestProject(t) + defer cleanup() + + t.Run("schedules destroy task successfully", func(t *testing.T) { + // Create work first + createTestWork(ctx, t, proj.DB, "w-sched-destroy", "branch", "root-1") + defer proj.DB.DeleteWork(ctx, "w-sched-destroy") + + err := control.ScheduleDestroyWorktree(ctx, proj, "w-sched-destroy") + require.NoError(t, err) + + // Verify task was scheduled + task, err := proj.DB.GetNextScheduledTask(ctx) + require.NoError(t, err) + require.NotNil(t, task) + assert.Equal(t, db.TaskTypeDestroyWorktree, task.TaskType) + assert.Equal(t, "w-sched-destroy", task.WorkID) + }) +} + +func TestTriggerPRFeedbackCheck(t *testing.T) { + ctx := context.Background() + proj, cleanup := setupTestProject(t) + defer cleanup() + + t.Run("triggers immediate feedback check", func(t *testing.T) { + // Create work with existing PR feedback task + createTestWork(ctx, t, proj.DB, "w-trigger", "branch", "root-1") + defer proj.DB.DeleteWork(ctx, "w-trigger") + + // First schedule a task for later + _, err := proj.DB.ScheduleTask(ctx, "w-trigger", db.TaskTypePRFeedback, time.Now().Add(1*time.Hour), nil) + require.NoError(t, err) + + // Trigger immediate check + err = control.TriggerPRFeedbackCheck(ctx, proj, "w-trigger") + require.NoError(t, err) + + // Verify the task's scheduled_at was updated to now (within tolerance) + task, err := proj.DB.GetNextScheduledTask(ctx) + require.NoError(t, err) + require.NotNil(t, task) + assert.Equal(t, db.TaskTypePRFeedback, task.TaskType) + // The task should be due now, not in an hour + assert.True(t, task.ScheduledAt.Before(time.Now().Add(1*time.Minute))) + }) +} + +func TestProcessAllDueTasksWithControlPlane(t *testing.T) { + ctx := context.Background() + proj, cleanup := setupTestProject(t) + defer cleanup() + + t.Run("processes due tasks and handles completion", func(t *testing.T) { + mocks := setupControlPlane() + + mocks.Git.PushSetUpstreamFunc = func(ctx context.Context, branch, dir string) error { + return nil + } + + // Create work and schedule a git push task + createTestWork(ctx, t, proj.DB, "w-process", "branch", "root-1") + defer proj.DB.DeleteWork(ctx, "w-process") + + _, err := proj.DB.ScheduleTask(ctx, "w-process", db.TaskTypeGitPush, time.Now(), map[string]string{ + "branch": "branch", + "dir": "/work/dir", + }) + require.NoError(t, err) + + // Process tasks + control.ProcessAllDueTasksWithControlPlane(ctx, proj, mocks.CP) + + // Verify git push was called + calls := mocks.Git.PushSetUpstreamCalls() + require.Len(t, calls, 1) + }) + + t.Run("handles unknown task type", func(t *testing.T) { + mocks := setupControlPlane() + + createTestWork(ctx, t, proj.DB, "w-unknown", "branch", "root-1") + defer proj.DB.DeleteWork(ctx, "w-unknown") + + // Schedule a task with unknown type + _, err := proj.DB.ScheduleTask(ctx, "w-unknown", "unknown_task_type", time.Now(), nil) + require.NoError(t, err) + + // Process tasks - should handle gracefully + control.ProcessAllDueTasksWithControlPlane(ctx, proj, mocks.CP) + + // No panic or error expected + }) + + t.Run("handles task failure with retry", func(t *testing.T) { + mocks := setupControlPlane() + + mocks.Git.PushSetUpstreamFunc = func(ctx context.Context, branch, dir string) error { + return errors.New("transient error") + } + + createTestWork(ctx, t, proj.DB, "w-retry", "branch", "root-1") + defer proj.DB.DeleteWork(ctx, "w-retry") + + // Schedule a task that will fail + err := proj.DB.ScheduleTaskWithRetry(ctx, "w-retry", db.TaskTypeGitPush, time.Now(), map[string]string{ + "branch": "branch", + "dir": "/work/dir", + }, "retry-test", 3) + require.NoError(t, err) + + // Process tasks - task should fail but be rescheduled + control.ProcessAllDueTasksWithControlPlane(ctx, proj, mocks.CP) + }) + + t.Run("processes multiple tasks in order", func(t *testing.T) { + mocks := setupControlPlane() + + callOrder := []string{} + mocks.Git.PushSetUpstreamFunc = func(ctx context.Context, branch, dir string) error { + callOrder = append(callOrder, branch) + return nil + } + + createTestWork(ctx, t, proj.DB, "w-multi", "branch", "root-1") + defer proj.DB.DeleteWork(ctx, "w-multi") + + // Schedule multiple tasks + _, err := proj.DB.ScheduleTask(ctx, "w-multi", db.TaskTypeGitPush, time.Now(), map[string]string{ + "branch": "first", + "dir": "/dir1", + }) + require.NoError(t, err) + + _, err = proj.DB.ScheduleTask(ctx, "w-multi", db.TaskTypeGitPush, time.Now(), map[string]string{ + "branch": "second", + "dir": "/dir2", + }) + require.NoError(t, err) + + // Process tasks + control.ProcessAllDueTasksWithControlPlane(ctx, proj, mocks.CP) + + // Both tasks should be processed + calls := mocks.Git.PushSetUpstreamCalls() + assert.Len(t, calls, 2) + }) +} + +func TestHandleTaskError(t *testing.T) { + ctx := context.Background() + proj, cleanup := setupTestProject(t) + defer cleanup() + + t.Run("reschedules task with retries remaining", func(t *testing.T) { + createTestWork(ctx, t, proj.DB, "w-error-retry", "branch", "root-1") + defer proj.DB.DeleteWork(ctx, "w-error-retry") + + // Create a task with retries + err := proj.DB.ScheduleTaskWithRetry(ctx, "w-error-retry", db.TaskTypeGitPush, time.Now(), nil, "error-test", 3) + require.NoError(t, err) + + // Get the task + task, err := proj.DB.GetNextScheduledTask(ctx) + require.NoError(t, err) + require.NotNil(t, task) + + // Mark as executing first + err = proj.DB.MarkTaskExecuting(ctx, task.ID) + require.NoError(t, err) + + // Handle error - task should be rescheduled due to retries remaining + control.HandleTaskError(ctx, proj, task, "test error") + + // Verify no pending tasks (task was rescheduled with future time) + // The task should have been rescheduled with backoff, not marked as failed + }) + + t.Run("marks task as failed when retries exhausted", func(t *testing.T) { + createTestWork(ctx, t, proj.DB, "w-error-fail", "branch", "root-1") + defer proj.DB.DeleteWork(ctx, "w-error-fail") + + // Create a task with only 1 max attempt + err := proj.DB.ScheduleTaskWithRetry(ctx, "w-error-fail", db.TaskTypeGitPush, time.Now(), nil, "fail-test", 1) + require.NoError(t, err) + + // Get the task + task, err := proj.DB.GetNextScheduledTask(ctx) + require.NoError(t, err) + require.NotNil(t, task) + + // Mark as executing + err = proj.DB.MarkTaskExecuting(ctx, task.ID) + require.NoError(t, err) + + // Set attempt count to max (exhausted retries) + task.AttemptCount = 1 + + // Handle error - should mark as failed since retries exhausted + control.HandleTaskError(ctx, proj, task, "final error") + }) +} + +func TestProcessAllDueTasks(t *testing.T) { + ctx := context.Background() + proj, cleanup := setupTestProject(t) + defer cleanup() + + t.Run("uses default control plane", func(t *testing.T) { + createTestWork(ctx, t, proj.DB, "w-default-cp", "branch", "root-1") + defer proj.DB.DeleteWork(ctx, "w-default-cp") + + // Schedule a task that will fail because default dependencies hit real services + // But this tests that ProcessAllDueTasks correctly creates a default ControlPlane + err := proj.DB.ScheduleTaskWithRetry(ctx, "w-default-cp", db.TaskTypeGitPush, time.Now(), map[string]string{ + "branch": "branch", + "dir": "/nonexistent", + }, "default-cp-test", 1) + require.NoError(t, err) + + // Process tasks - should not panic even though task will fail + control.ProcessAllDueTasks(ctx, proj) + }) +} diff --git a/internal/control/plane.go b/internal/control/plane.go index aaba8c3e..d3b82141 100644 --- a/internal/control/plane.go +++ b/internal/control/plane.go @@ -1,3 +1,5 @@ +//go:generate moq -stub -out control_mock_test.go -pkg control_test . OrchestratorSpawner WorkDestroyer + package control import ( diff --git a/internal/feedback/feedback_mock.go b/internal/feedback/feedback_mock.go new file mode 100644 index 00000000..bba740a2 --- /dev/null +++ b/internal/feedback/feedback_mock.go @@ -0,0 +1,99 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package feedback + +import ( + "context" + "github.com/newhook/co/internal/db" + "github.com/newhook/co/internal/project" + "sync" +) + +// Ensure, that FeedbackProcessorMock does implement Processor. +// If this is not the case, regenerate this file with moq. +var _ Processor = &FeedbackProcessorMock{} + +// FeedbackProcessorMock is a mock implementation of Processor. +// +// func TestSomethingThatUsesProcessor(t *testing.T) { +// +// // make and configure a mocked Processor +// mockedProcessor := &FeedbackProcessorMock{ +// ProcessPRFeedbackFunc: func(ctx context.Context, proj *project.Project, database *db.DB, workID string) (int, error) { +// panic("mock out the ProcessPRFeedback method") +// }, +// } +// +// // use mockedProcessor in code that requires Processor +// // and then make assertions. +// +// } +type FeedbackProcessorMock struct { + // ProcessPRFeedbackFunc mocks the ProcessPRFeedback method. + ProcessPRFeedbackFunc func(ctx context.Context, proj *project.Project, database *db.DB, workID string) (int, error) + + // calls tracks calls to the methods. + calls struct { + // ProcessPRFeedback holds details about calls to the ProcessPRFeedback method. + ProcessPRFeedback []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Proj is the proj argument value. + Proj *project.Project + // Database is the database argument value. + Database *db.DB + // WorkID is the workID argument value. + WorkID string + } + } + lockProcessPRFeedback sync.RWMutex +} + +// ProcessPRFeedback calls ProcessPRFeedbackFunc. +func (mock *FeedbackProcessorMock) ProcessPRFeedback(ctx context.Context, proj *project.Project, database *db.DB, workID string) (int, error) { + callInfo := struct { + Ctx context.Context + Proj *project.Project + Database *db.DB + WorkID string + }{ + Ctx: ctx, + Proj: proj, + Database: database, + WorkID: workID, + } + mock.lockProcessPRFeedback.Lock() + mock.calls.ProcessPRFeedback = append(mock.calls.ProcessPRFeedback, callInfo) + mock.lockProcessPRFeedback.Unlock() + if mock.ProcessPRFeedbackFunc == nil { + var ( + nOut int + errOut error + ) + return nOut, errOut + } + return mock.ProcessPRFeedbackFunc(ctx, proj, database, workID) +} + +// ProcessPRFeedbackCalls gets all the calls that were made to ProcessPRFeedback. +// Check the length with: +// +// len(mockedProcessor.ProcessPRFeedbackCalls()) +func (mock *FeedbackProcessorMock) ProcessPRFeedbackCalls() []struct { + Ctx context.Context + Proj *project.Project + Database *db.DB + WorkID string +} { + var calls []struct { + Ctx context.Context + Proj *project.Project + Database *db.DB + WorkID string + } + mock.lockProcessPRFeedback.RLock() + calls = mock.calls.ProcessPRFeedback + mock.lockProcessPRFeedback.RUnlock() + return calls +} diff --git a/internal/feedback/interface.go b/internal/feedback/interface.go index d93c9c7a..c8812804 100644 --- a/internal/feedback/interface.go +++ b/internal/feedback/interface.go @@ -1,3 +1,5 @@ +//go:generate moq -stub -out feedback_mock.go . Processor:FeedbackProcessorMock + package feedback import ( diff --git a/internal/feedback/processor_test.go b/internal/feedback/processor_test.go index cb085a3f..135553cf 100644 --- a/internal/feedback/processor_test.go +++ b/internal/feedback/processor_test.go @@ -6,18 +6,15 @@ import ( "github.com/newhook/co/internal/db" "github.com/newhook/co/internal/github" + "github.com/stretchr/testify/require" ) func TestNewFeedbackProcessor(t *testing.T) { client := &github.Client{} processor := NewFeedbackProcessor(client) - if processor == nil { - t.Fatal("NewFeedbackProcessor returned nil") - } - if processor.client != client { - t.Error("Expected client to be set") - } + require.NotNil(t, processor, "NewFeedbackProcessor returned nil") + require.Equal(t, client, processor.client, "Expected client to be set") } func TestCategorizeCheckFailure(t *testing.T) { @@ -43,9 +40,7 @@ func TestCategorizeCheckFailure(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := processor.categorizeCheckFailure(tt.check) - if result != tt.expected { - t.Errorf("categorizeCheckFailure(%s) = %v, want %v", tt.check, result, tt.expected) - } + require.Equal(t, tt.expected, result) }) } } @@ -70,10 +65,7 @@ func TestCategorizeWorkflowFailure(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := processor.categorizeWorkflowFailure(tt.workflowName, tt.failureDetail) - if result != tt.expected { - t.Errorf("categorizeWorkflowFailure(%s, %s) = %v, want %v", - tt.workflowName, tt.failureDetail, result, tt.expected) - } + require.Equal(t, tt.expected, result) }) } } @@ -97,9 +89,7 @@ func TestGetPriorityForType(t *testing.T) { for _, tt := range tests { t.Run(string(tt.feedbackType), func(t *testing.T) { result := processor.getPriorityForType(tt.feedbackType) - if result != tt.expected { - t.Errorf("getPriorityForType(%v) = %d, want %d", tt.feedbackType, result, tt.expected) - } + require.Equal(t, tt.expected, result) }) } } @@ -129,9 +119,7 @@ func TestIsActionableComment(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := processor.isActionableComment(tt.body) - if result != tt.actionable { - t.Errorf("isActionableComment(%s) = %v, want %v", tt.body, result, tt.actionable) - } + require.Equal(t, tt.actionable, result) }) } } @@ -154,9 +142,7 @@ func TestTruncateText(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := processor.truncateText(tt.text, tt.maxLen) - if result != tt.expected { - t.Errorf("truncateText(%s, %d) = %s, want %s", tt.text, tt.maxLen, result, tt.expected) - } + require.Equal(t, tt.expected, result) }) } } @@ -190,25 +176,15 @@ func TestProcessStatusChecks(t *testing.T) { items := processor.processStatusChecks(status) // Should have 2 items (the two failures) - if len(items) != 2 { - t.Fatalf("Expected 2 feedback items, got %d", len(items)) - } + require.Len(t, items, 2) // Check first item (unit-tests failure) - if items[0].Type != github.FeedbackTypeTest { - t.Errorf("First item type = %v, want %v", items[0].Type, github.FeedbackTypeTest) - } - if items[0].Title != "Fix unit-tests failure" { - t.Errorf("First item title = %s, want 'Fix unit-tests failure'", items[0].Title) - } + require.Equal(t, github.FeedbackTypeTest, items[0].Type) + require.Equal(t, "Fix unit-tests failure", items[0].Title) // Check second item (lint error) - if items[1].Type != github.FeedbackTypeLint { - t.Errorf("Second item type = %v, want %v", items[1].Type, github.FeedbackTypeLint) - } - if items[1].Title != "Fix lint failure" { - t.Errorf("Second item title = %s, want 'Fix lint failure'", items[1].Title) - } + require.Equal(t, github.FeedbackTypeLint, items[1].Type) + require.Equal(t, "Fix lint failure", items[1].Title) } func TestProcessWorkflowRuns(t *testing.T) { @@ -252,17 +228,11 @@ func TestProcessWorkflowRuns(t *testing.T) { items := processor.processWorkflowRuns(ctx, "owner/repo", status) // Should have 1 item (the failed workflow with generic fallback) - if len(items) != 1 { - t.Fatalf("Expected 1 feedback item, got %d", len(items)) - } + require.Len(t, items, 1) - if items[0].Type != github.FeedbackTypeTest { - t.Errorf("Item type = %v, want %v", items[0].Type, github.FeedbackTypeTest) - } + require.Equal(t, github.FeedbackTypeTest, items[0].Type) // Generic fallback format: "Fix {jobName}: {stepName} in {workflowName}" - if items[0].Title != "Fix Unit Tests: Run tests in Test Suite" { - t.Errorf("Item title = %s, want 'Fix Unit Tests: Run tests in Test Suite'", items[0].Title) - } + require.Equal(t, "Fix Unit Tests: Run tests in Test Suite", items[0].Title) } func TestProcessReviews(t *testing.T) { @@ -303,28 +273,16 @@ func TestProcessReviews(t *testing.T) { items := processor.processReviews(status) // Should have 2 items (CHANGES_REQUESTED and the actionable comment) - if len(items) != 2 { - t.Fatalf("Expected 2 feedback items, got %d", len(items)) - } + require.Len(t, items, 2) // Check first item (CHANGES_REQUESTED) - if items[0].Type != github.FeedbackTypeReview { - t.Errorf("First item type = %v, want %v", items[0].Type, github.FeedbackTypeReview) - } - if items[0].Title != "Address review feedback from reviewer1" { - t.Errorf("First item title = %s", items[0].Title) - } - if items[0].Priority != 1 { - t.Errorf("First item priority = %d, want 1", items[0].Priority) - } + require.Equal(t, github.FeedbackTypeReview, items[0].Type) + require.Equal(t, "Address review feedback from reviewer1", items[0].Title) + require.Equal(t, 1, items[0].Priority) // Check second item (actionable comment) - if items[1].Type != github.FeedbackTypeReview { - t.Errorf("Second item type = %v, want %v", items[1].Type, github.FeedbackTypeReview) - } - if items[1].Priority != 2 { - t.Errorf("Second item priority = %d, want 2", items[1].Priority) - } + require.Equal(t, github.FeedbackTypeReview, items[1].Type) + require.Equal(t, 2, items[1].Priority) } func TestCreateGenericFailureItem(t *testing.T) { @@ -373,9 +331,7 @@ func TestCreateGenericFailureItem(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { item := processor.createGenericFailureItem(workflow, tt.job) - if item.Title != tt.expectedTitle { - t.Errorf("createGenericFailureItem().Title = %s, want %s", item.Title, tt.expectedTitle) - } + require.Equal(t, tt.expectedTitle, item.Title) }) } } @@ -422,14 +378,10 @@ func TestCategorizeComment_HumanVsBot(t *testing.T) { } feedbackType := processor.categorizeComment(comment) - if feedbackType != tt.expectedType { - t.Errorf("categorizeComment() type = %v, want %v", feedbackType, tt.expectedType) - } + require.Equal(t, tt.expectedType, feedbackType) priority := processor.getPriorityForType(feedbackType) - if priority != tt.expectedPriority { - t.Errorf("getPriorityForType(%v) = %d, want %d", feedbackType, priority, tt.expectedPriority) - } + require.Equal(t, tt.expectedPriority, priority) }) } } @@ -487,10 +439,7 @@ func TestCategorizeComment(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := processor.categorizeComment(tt.comment) - if result != tt.expected { - t.Errorf("categorizeComment(%+v) = %v, want %v", - tt.comment, result, tt.expected) - } + require.Equal(t, tt.expected, result) }) } } @@ -533,10 +482,7 @@ func TestExtractTitleFromComment(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := processor.extractTitleFromComment(tt.body) - if result != tt.expected { - t.Errorf("extractTitleFromComment(%s) = %s, want %s", - tt.body, result, tt.expected) - } + require.Equal(t, tt.expected, result) }) } } @@ -568,19 +514,13 @@ func TestProcessComments(t *testing.T) { items := processor.processComments(status) // Should have 2 actionable items (security and test failure) - if len(items) != 2 { - t.Fatalf("Expected 2 feedback items, got %d", len(items)) - } + require.Len(t, items, 2) // First should be security (higher priority) - if items[0].Type != github.FeedbackTypeSecurity { - t.Errorf("First item type = %v, want %v", items[0].Type, github.FeedbackTypeSecurity) - } + require.Equal(t, github.FeedbackTypeSecurity, items[0].Type) // Second should be test failure - if items[1].Type != github.FeedbackTypeTest { - t.Errorf("Second item type = %v, want %v", items[1].Type, github.FeedbackTypeTest) - } + require.Equal(t, github.FeedbackTypeTest, items[1].Type) } func TestProcessConflicts(t *testing.T) { @@ -607,24 +547,14 @@ func TestProcessConflicts(t *testing.T) { items := processor.processConflicts(status) - if len(items) != tt.expectItems { - t.Errorf("processConflicts() returned %d items, want %d", len(items), tt.expectItems) - } + require.Len(t, items, tt.expectItems) if tt.expectItems > 0 { item := items[0] - if item.Type != github.FeedbackTypeConflict { - t.Errorf("Item type = %v, want %v", item.Type, github.FeedbackTypeConflict) - } - if item.Title != "Resolve merge conflicts with main" { - t.Errorf("Item title = %s, want 'Resolve merge conflicts with main'", item.Title) - } - if item.Priority != 1 { - t.Errorf("Item priority = %d, want 1", item.Priority) - } - if item.Source.ID != "merge-conflict" { - t.Errorf("Item source ID = %s, want 'merge-conflict'", item.Source.ID) - } + require.Equal(t, github.FeedbackTypeConflict, item.Type) + require.Equal(t, "Resolve merge conflicts with main", item.Title) + require.Equal(t, 1, item.Priority) + require.Equal(t, "merge-conflict", item.Source.ID) } }) } @@ -634,9 +564,7 @@ func TestGetPriorityForConflictType(t *testing.T) { processor := &FeedbackProcessor{} result := processor.getPriorityForType(github.FeedbackTypeConflict) - if result != 1 { - t.Errorf("getPriorityForType(FeedbackTypeConflict) = %d, want 1", result) - } + require.Equal(t, 1, result) } func TestNewFeedbackProcessorWithProject(t *testing.T) { @@ -644,26 +572,16 @@ func TestNewFeedbackProcessorWithProject(t *testing.T) { t.Run("with nil project", func(t *testing.T) { processor := NewFeedbackProcessorWithProject(client, nil, "work-123") - if processor == nil { - t.Fatal("NewFeedbackProcessorWithProject returned nil") - } - if processor.proj != nil { - t.Error("Expected proj to be nil") - } - if processor.workID != "work-123" { - t.Errorf("workID = %s, want work-123", processor.workID) - } + require.NotNil(t, processor, "NewFeedbackProcessorWithProject returned nil") + require.Nil(t, processor.proj, "Expected proj to be nil") + require.Equal(t, "work-123", processor.workID) }) t.Run("stores all parameters", func(t *testing.T) { // Can't test with real project, but we can verify struct fields are set processor := NewFeedbackProcessorWithProject(client, nil, "w-abc") - if processor.client != client { - t.Error("Expected client to be set") - } - if processor.workID != "w-abc" { - t.Errorf("workID = %s, want w-abc", processor.workID) - } + require.Equal(t, client, processor.client, "Expected client to be set") + require.Equal(t, "w-abc", processor.workID) }) } @@ -672,16 +590,12 @@ func TestShouldUseClaude(t *testing.T) { t.Run("returns false when project is nil", func(t *testing.T) { processor := NewFeedbackProcessorWithProject(client, nil, "work-123") - if processor.shouldUseClaude() { - t.Error("Expected shouldUseClaude() to return false when project is nil") - } + require.False(t, processor.shouldUseClaude(), "Expected shouldUseClaude() to return false when project is nil") }) t.Run("returns false with basic processor", func(t *testing.T) { processor := NewFeedbackProcessor(client) - if processor.shouldUseClaude() { - t.Error("Expected shouldUseClaude() to return false for basic processor") - } + require.False(t, processor.shouldUseClaude(), "Expected shouldUseClaude() to return false for basic processor") }) } @@ -733,10 +647,7 @@ func TestTruncateLogContent(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := truncateLogContent(tt.logs, tt.maxBytes) - if result != tt.expected { - t.Errorf("truncateLogContent(%q, %d) = %q, want %q", - tt.logs, tt.maxBytes, result, tt.expected) - } + require.Equal(t, tt.expected, result) }) } } diff --git a/internal/feedback/status_test.go b/internal/feedback/status_test.go index e08b2611..114e9cb6 100644 --- a/internal/feedback/status_test.go +++ b/internal/feedback/status_test.go @@ -6,6 +6,7 @@ import ( "github.com/newhook/co/internal/db" "github.com/newhook/co/internal/github" + "github.com/stretchr/testify/require" ) func TestExtractCIStatus(t *testing.T) { @@ -111,9 +112,7 @@ func TestExtractCIStatus(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := extractCIStatus(tt.status) - if result != tt.expected { - t.Errorf("extractCIStatus() = %q, want %q", result, tt.expected) - } + require.Equal(t, tt.expected, result) }) } } @@ -235,22 +234,18 @@ func TestExtractApprovalStatus(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { status, approvers := extractApprovalStatus(tt.status) - if status != tt.expectedStatus { - t.Errorf("extractApprovalStatus() status = %q, want %q", status, tt.expectedStatus) - } + require.Equal(t, tt.expectedStatus, status) // Check approvers (order may vary) - if len(approvers) != len(tt.expectedApprovers) { - t.Errorf("extractApprovalStatus() approvers count = %d, want %d", len(approvers), len(tt.expectedApprovers)) - } else { + require.Len(t, approvers, len(tt.expectedApprovers)) + + if len(approvers) > 0 { approverSet := make(map[string]bool) for _, a := range approvers { approverSet[a] = true } for _, expected := range tt.expectedApprovers { - if !approverSet[expected] { - t.Errorf("extractApprovalStatus() missing approver %q", expected) - } + require.True(t, approverSet[expected], "missing approver %q", expected) } } }) @@ -322,15 +317,9 @@ func TestExtractStatusFromPRStatus(t *testing.T) { t.Run(tt.name, func(t *testing.T) { info := ExtractStatusFromPRStatus(tt.status) - if info.CIStatus != tt.expectedCI { - t.Errorf("CIStatus = %q, want %q", info.CIStatus, tt.expectedCI) - } - if info.ApprovalStatus != tt.expectedApproval { - t.Errorf("ApprovalStatus = %q, want %q", info.ApprovalStatus, tt.expectedApproval) - } - if len(info.Approvers) != len(tt.expectedApprovers) { - t.Errorf("Approvers count = %d, want %d", len(info.Approvers), len(tt.expectedApprovers)) - } + require.Equal(t, tt.expectedCI, info.CIStatus) + require.Equal(t, tt.expectedApproval, info.ApprovalStatus) + require.Len(t, info.Approvers, len(tt.expectedApprovers)) }) } } diff --git a/internal/git/git.go b/internal/git/git.go index 346c5741..a6878b49 100644 --- a/internal/git/git.go +++ b/internal/git/git.go @@ -1,5 +1,7 @@ package git +//go:generate moq -stub -out git_mock.go . Operations:GitOperationsMock + import ( "context" "fmt" diff --git a/internal/git/git_mock.go b/internal/git/git_mock.go new file mode 100644 index 00000000..86e63e8b --- /dev/null +++ b/internal/git/git_mock.go @@ -0,0 +1,500 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package git + +import ( + "context" + "sync" +) + +// Ensure, that GitOperationsMock does implement Operations. +// If this is not the case, regenerate this file with moq. +var _ Operations = &GitOperationsMock{} + +// GitOperationsMock is a mock implementation of Operations. +// +// func TestSomethingThatUsesOperations(t *testing.T) { +// +// // make and configure a mocked Operations +// mockedOperations := &GitOperationsMock{ +// BranchExistsFunc: func(ctx context.Context, repoPath string, branchName string) bool { +// panic("mock out the BranchExists method") +// }, +// CloneFunc: func(ctx context.Context, source string, dest string) error { +// panic("mock out the Clone method") +// }, +// FetchBranchFunc: func(ctx context.Context, repoPath string, branch string) error { +// panic("mock out the FetchBranch method") +// }, +// FetchPRRefFunc: func(ctx context.Context, repoPath string, prNumber int, localBranch string) error { +// panic("mock out the FetchPRRef method") +// }, +// ListBranchesFunc: func(ctx context.Context, repoPath string) ([]string, error) { +// panic("mock out the ListBranches method") +// }, +// PullFunc: func(ctx context.Context, dir string) error { +// panic("mock out the Pull method") +// }, +// PushSetUpstreamFunc: func(ctx context.Context, branch string, dir string) error { +// panic("mock out the PushSetUpstream method") +// }, +// ValidateExistingBranchFunc: func(ctx context.Context, repoPath string, branchName string) (bool, bool, error) { +// panic("mock out the ValidateExistingBranch method") +// }, +// } +// +// // use mockedOperations in code that requires Operations +// // and then make assertions. +// +// } +type GitOperationsMock struct { + // BranchExistsFunc mocks the BranchExists method. + BranchExistsFunc func(ctx context.Context, repoPath string, branchName string) bool + + // CloneFunc mocks the Clone method. + CloneFunc func(ctx context.Context, source string, dest string) error + + // FetchBranchFunc mocks the FetchBranch method. + FetchBranchFunc func(ctx context.Context, repoPath string, branch string) error + + // FetchPRRefFunc mocks the FetchPRRef method. + FetchPRRefFunc func(ctx context.Context, repoPath string, prNumber int, localBranch string) error + + // ListBranchesFunc mocks the ListBranches method. + ListBranchesFunc func(ctx context.Context, repoPath string) ([]string, error) + + // PullFunc mocks the Pull method. + PullFunc func(ctx context.Context, dir string) error + + // PushSetUpstreamFunc mocks the PushSetUpstream method. + PushSetUpstreamFunc func(ctx context.Context, branch string, dir string) error + + // ValidateExistingBranchFunc mocks the ValidateExistingBranch method. + ValidateExistingBranchFunc func(ctx context.Context, repoPath string, branchName string) (bool, bool, error) + + // calls tracks calls to the methods. + calls struct { + // BranchExists holds details about calls to the BranchExists method. + BranchExists []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // RepoPath is the repoPath argument value. + RepoPath string + // BranchName is the branchName argument value. + BranchName string + } + // Clone holds details about calls to the Clone method. + Clone []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Source is the source argument value. + Source string + // Dest is the dest argument value. + Dest string + } + // FetchBranch holds details about calls to the FetchBranch method. + FetchBranch []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // RepoPath is the repoPath argument value. + RepoPath string + // Branch is the branch argument value. + Branch string + } + // FetchPRRef holds details about calls to the FetchPRRef method. + FetchPRRef []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // RepoPath is the repoPath argument value. + RepoPath string + // PrNumber is the prNumber argument value. + PrNumber int + // LocalBranch is the localBranch argument value. + LocalBranch string + } + // ListBranches holds details about calls to the ListBranches method. + ListBranches []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // RepoPath is the repoPath argument value. + RepoPath string + } + // Pull holds details about calls to the Pull method. + Pull []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Dir is the dir argument value. + Dir string + } + // PushSetUpstream holds details about calls to the PushSetUpstream method. + PushSetUpstream []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Branch is the branch argument value. + Branch string + // Dir is the dir argument value. + Dir string + } + // ValidateExistingBranch holds details about calls to the ValidateExistingBranch method. + ValidateExistingBranch []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // RepoPath is the repoPath argument value. + RepoPath string + // BranchName is the branchName argument value. + BranchName string + } + } + lockBranchExists sync.RWMutex + lockClone sync.RWMutex + lockFetchBranch sync.RWMutex + lockFetchPRRef sync.RWMutex + lockListBranches sync.RWMutex + lockPull sync.RWMutex + lockPushSetUpstream sync.RWMutex + lockValidateExistingBranch sync.RWMutex +} + +// BranchExists calls BranchExistsFunc. +func (mock *GitOperationsMock) BranchExists(ctx context.Context, repoPath string, branchName string) bool { + callInfo := struct { + Ctx context.Context + RepoPath string + BranchName string + }{ + Ctx: ctx, + RepoPath: repoPath, + BranchName: branchName, + } + mock.lockBranchExists.Lock() + mock.calls.BranchExists = append(mock.calls.BranchExists, callInfo) + mock.lockBranchExists.Unlock() + if mock.BranchExistsFunc == nil { + var ( + bOut bool + ) + return bOut + } + return mock.BranchExistsFunc(ctx, repoPath, branchName) +} + +// BranchExistsCalls gets all the calls that were made to BranchExists. +// Check the length with: +// +// len(mockedOperations.BranchExistsCalls()) +func (mock *GitOperationsMock) BranchExistsCalls() []struct { + Ctx context.Context + RepoPath string + BranchName string +} { + var calls []struct { + Ctx context.Context + RepoPath string + BranchName string + } + mock.lockBranchExists.RLock() + calls = mock.calls.BranchExists + mock.lockBranchExists.RUnlock() + return calls +} + +// Clone calls CloneFunc. +func (mock *GitOperationsMock) Clone(ctx context.Context, source string, dest string) error { + callInfo := struct { + Ctx context.Context + Source string + Dest string + }{ + Ctx: ctx, + Source: source, + Dest: dest, + } + mock.lockClone.Lock() + mock.calls.Clone = append(mock.calls.Clone, callInfo) + mock.lockClone.Unlock() + if mock.CloneFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.CloneFunc(ctx, source, dest) +} + +// CloneCalls gets all the calls that were made to Clone. +// Check the length with: +// +// len(mockedOperations.CloneCalls()) +func (mock *GitOperationsMock) CloneCalls() []struct { + Ctx context.Context + Source string + Dest string +} { + var calls []struct { + Ctx context.Context + Source string + Dest string + } + mock.lockClone.RLock() + calls = mock.calls.Clone + mock.lockClone.RUnlock() + return calls +} + +// FetchBranch calls FetchBranchFunc. +func (mock *GitOperationsMock) FetchBranch(ctx context.Context, repoPath string, branch string) error { + callInfo := struct { + Ctx context.Context + RepoPath string + Branch string + }{ + Ctx: ctx, + RepoPath: repoPath, + Branch: branch, + } + mock.lockFetchBranch.Lock() + mock.calls.FetchBranch = append(mock.calls.FetchBranch, callInfo) + mock.lockFetchBranch.Unlock() + if mock.FetchBranchFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.FetchBranchFunc(ctx, repoPath, branch) +} + +// FetchBranchCalls gets all the calls that were made to FetchBranch. +// Check the length with: +// +// len(mockedOperations.FetchBranchCalls()) +func (mock *GitOperationsMock) FetchBranchCalls() []struct { + Ctx context.Context + RepoPath string + Branch string +} { + var calls []struct { + Ctx context.Context + RepoPath string + Branch string + } + mock.lockFetchBranch.RLock() + calls = mock.calls.FetchBranch + mock.lockFetchBranch.RUnlock() + return calls +} + +// FetchPRRef calls FetchPRRefFunc. +func (mock *GitOperationsMock) FetchPRRef(ctx context.Context, repoPath string, prNumber int, localBranch string) error { + callInfo := struct { + Ctx context.Context + RepoPath string + PrNumber int + LocalBranch string + }{ + Ctx: ctx, + RepoPath: repoPath, + PrNumber: prNumber, + LocalBranch: localBranch, + } + mock.lockFetchPRRef.Lock() + mock.calls.FetchPRRef = append(mock.calls.FetchPRRef, callInfo) + mock.lockFetchPRRef.Unlock() + if mock.FetchPRRefFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.FetchPRRefFunc(ctx, repoPath, prNumber, localBranch) +} + +// FetchPRRefCalls gets all the calls that were made to FetchPRRef. +// Check the length with: +// +// len(mockedOperations.FetchPRRefCalls()) +func (mock *GitOperationsMock) FetchPRRefCalls() []struct { + Ctx context.Context + RepoPath string + PrNumber int + LocalBranch string +} { + var calls []struct { + Ctx context.Context + RepoPath string + PrNumber int + LocalBranch string + } + mock.lockFetchPRRef.RLock() + calls = mock.calls.FetchPRRef + mock.lockFetchPRRef.RUnlock() + return calls +} + +// ListBranches calls ListBranchesFunc. +func (mock *GitOperationsMock) ListBranches(ctx context.Context, repoPath string) ([]string, error) { + callInfo := struct { + Ctx context.Context + RepoPath string + }{ + Ctx: ctx, + RepoPath: repoPath, + } + mock.lockListBranches.Lock() + mock.calls.ListBranches = append(mock.calls.ListBranches, callInfo) + mock.lockListBranches.Unlock() + if mock.ListBranchesFunc == nil { + var ( + stringsOut []string + errOut error + ) + return stringsOut, errOut + } + return mock.ListBranchesFunc(ctx, repoPath) +} + +// ListBranchesCalls gets all the calls that were made to ListBranches. +// Check the length with: +// +// len(mockedOperations.ListBranchesCalls()) +func (mock *GitOperationsMock) ListBranchesCalls() []struct { + Ctx context.Context + RepoPath string +} { + var calls []struct { + Ctx context.Context + RepoPath string + } + mock.lockListBranches.RLock() + calls = mock.calls.ListBranches + mock.lockListBranches.RUnlock() + return calls +} + +// Pull calls PullFunc. +func (mock *GitOperationsMock) Pull(ctx context.Context, dir string) error { + callInfo := struct { + Ctx context.Context + Dir string + }{ + Ctx: ctx, + Dir: dir, + } + mock.lockPull.Lock() + mock.calls.Pull = append(mock.calls.Pull, callInfo) + mock.lockPull.Unlock() + if mock.PullFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.PullFunc(ctx, dir) +} + +// PullCalls gets all the calls that were made to Pull. +// Check the length with: +// +// len(mockedOperations.PullCalls()) +func (mock *GitOperationsMock) PullCalls() []struct { + Ctx context.Context + Dir string +} { + var calls []struct { + Ctx context.Context + Dir string + } + mock.lockPull.RLock() + calls = mock.calls.Pull + mock.lockPull.RUnlock() + return calls +} + +// PushSetUpstream calls PushSetUpstreamFunc. +func (mock *GitOperationsMock) PushSetUpstream(ctx context.Context, branch string, dir string) error { + callInfo := struct { + Ctx context.Context + Branch string + Dir string + }{ + Ctx: ctx, + Branch: branch, + Dir: dir, + } + mock.lockPushSetUpstream.Lock() + mock.calls.PushSetUpstream = append(mock.calls.PushSetUpstream, callInfo) + mock.lockPushSetUpstream.Unlock() + if mock.PushSetUpstreamFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.PushSetUpstreamFunc(ctx, branch, dir) +} + +// PushSetUpstreamCalls gets all the calls that were made to PushSetUpstream. +// Check the length with: +// +// len(mockedOperations.PushSetUpstreamCalls()) +func (mock *GitOperationsMock) PushSetUpstreamCalls() []struct { + Ctx context.Context + Branch string + Dir string +} { + var calls []struct { + Ctx context.Context + Branch string + Dir string + } + mock.lockPushSetUpstream.RLock() + calls = mock.calls.PushSetUpstream + mock.lockPushSetUpstream.RUnlock() + return calls +} + +// ValidateExistingBranch calls ValidateExistingBranchFunc. +func (mock *GitOperationsMock) ValidateExistingBranch(ctx context.Context, repoPath string, branchName string) (bool, bool, error) { + callInfo := struct { + Ctx context.Context + RepoPath string + BranchName string + }{ + Ctx: ctx, + RepoPath: repoPath, + BranchName: branchName, + } + mock.lockValidateExistingBranch.Lock() + mock.calls.ValidateExistingBranch = append(mock.calls.ValidateExistingBranch, callInfo) + mock.lockValidateExistingBranch.Unlock() + if mock.ValidateExistingBranchFunc == nil { + var ( + existsLocalOut bool + existsRemoteOut bool + errOut error + ) + return existsLocalOut, existsRemoteOut, errOut + } + return mock.ValidateExistingBranchFunc(ctx, repoPath, branchName) +} + +// ValidateExistingBranchCalls gets all the calls that were made to ValidateExistingBranch. +// Check the length with: +// +// len(mockedOperations.ValidateExistingBranchCalls()) +func (mock *GitOperationsMock) ValidateExistingBranchCalls() []struct { + Ctx context.Context + RepoPath string + BranchName string +} { + var calls []struct { + Ctx context.Context + RepoPath string + BranchName string + } + mock.lockValidateExistingBranch.RLock() + calls = mock.calls.ValidateExistingBranch + mock.lockValidateExistingBranch.RUnlock() + return calls +} diff --git a/internal/git/git_test.go b/internal/git/git_test.go new file mode 100644 index 00000000..648f6ccd --- /dev/null +++ b/internal/git/git_test.go @@ -0,0 +1,13 @@ +package git_test + +import ( + "testing" + + "github.com/newhook/co/internal/git" + "github.com/stretchr/testify/require" +) + +func TestNewOperations(t *testing.T) { + ops := git.NewOperations() + require.NotNil(t, ops, "NewOperations returned nil") +} diff --git a/internal/github/client.go b/internal/github/client.go index baff969f..fdd72633 100644 --- a/internal/github/client.go +++ b/internal/github/client.go @@ -1,5 +1,7 @@ package github +//go:generate moq -stub -out github_mock.go . ClientInterface:GitHubClientMock + import ( "context" "encoding/json" diff --git a/internal/github/client_comment_test.go b/internal/github/client_comment_test.go index 7a059877..1e116f68 100644 --- a/internal/github/client_comment_test.go +++ b/internal/github/client_comment_test.go @@ -5,6 +5,8 @@ import ( "fmt" "os" "testing" + + "github.com/stretchr/testify/require" ) func TestPostPRComment(t *testing.T) { @@ -18,18 +20,14 @@ func TestPostPRComment(t *testing.T) { // You need to set a real PR URL here when testing manually prURL := os.Getenv("TEST_PR_URL") - if prURL == "" { - t.Fatal("TEST_PR_URL environment variable must be set for manual testing") - } + require.NotEmpty(t, prURL, "TEST_PR_URL environment variable must be set for manual testing") client := NewClient() ctx := context.Background() // Test posting a simple comment err := client.PostPRComment(ctx, prURL, "Test comment from co integration test") - if err != nil { - t.Fatalf("Failed to post PR comment: %v", err) - } + require.NoError(t, err, "Failed to post PR comment") t.Log("Successfully posted PR comment") } @@ -44,9 +42,7 @@ func TestPostReplyToComment(t *testing.T) { } prURL := os.Getenv("TEST_PR_URL") - if prURL == "" { - t.Fatal("TEST_PR_URL environment variable must be set for manual testing") - } + require.NotEmpty(t, prURL, "TEST_PR_URL environment variable must be set for manual testing") // You need to set a real comment ID here when testing manually commentID := 123456789 // Replace with actual comment ID @@ -56,9 +52,7 @@ func TestPostReplyToComment(t *testing.T) { // Test posting a reply to a comment err := client.PostReplyToComment(ctx, prURL, commentID, "Test reply from co integration test") - if err != nil { - t.Fatalf("Failed to post reply to comment: %v", err) - } + require.NoError(t, err, "Failed to post reply to comment") t.Log("Successfully posted reply to comment") } @@ -73,9 +67,7 @@ func TestCommentIntegration(t *testing.T) { } prURL := os.Getenv("TEST_PR_URL") - if prURL == "" { - t.Fatal("TEST_PR_URL environment variable must be set for manual testing") - } + require.NotEmpty(t, prURL, "TEST_PR_URL environment variable must be set for manual testing") client := NewClient() ctx := context.Background() @@ -90,9 +82,7 @@ func TestCommentIntegration(t *testing.T) { beadID, feedbackTitle, priority) err := client.PostPRComment(ctx, prURL, ackMessage) - if err != nil { - t.Fatalf("Failed to post acknowledgment: %v", err) - } + require.NoError(t, err, "Failed to post acknowledgment") t.Log("Successfully posted bead acknowledgment to PR") -} \ No newline at end of file +} diff --git a/internal/github/client_test.go b/internal/github/client_test.go index 27d989be..987df9be 100644 --- a/internal/github/client_test.go +++ b/internal/github/client_test.go @@ -3,22 +3,22 @@ package github import ( "context" "testing" + + "github.com/stretchr/testify/require" ) func TestNewClient(t *testing.T) { client := NewClient() - if client == nil { - t.Fatal("NewClient returned nil") - } + require.NotNil(t, client, "NewClient returned nil") } func TestParsePRURL(t *testing.T) { tests := []struct { - name string - prURL string - wantNumber string - wantRepo string - wantErr bool + name string + prURL string + wantNumber string + wantRepo string + wantErr bool }{ { name: "Valid GitHub PR URL", @@ -117,18 +117,14 @@ func TestParsePRURL(t *testing.T) { t.Run(tt.name, func(t *testing.T) { prNumber, repo, err := parsePRURL(tt.prURL) - if (err != nil) != tt.wantErr { - t.Errorf("parsePRURL() error = %v, wantErr %v", err, tt.wantErr) - return + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) } - if prNumber != tt.wantNumber { - t.Errorf("parsePRURL() prNumber = %v, want %v", prNumber, tt.wantNumber) - } - - if repo != tt.wantRepo { - t.Errorf("parsePRURL() repo = %v, want %v", repo, tt.wantRepo) - } + require.Equal(t, tt.wantNumber, prNumber) + require.Equal(t, tt.wantRepo, repo) }) } } @@ -145,12 +141,8 @@ func TestGetPRStatus(t *testing.T) { status, err := client.GetPRStatus(ctx, prURL) if err == nil { - if status == nil { - t.Error("GetPRStatus returned nil status without error") - } - if status.URL != prURL { - t.Errorf("Status URL = %s, want %s", status.URL, prURL) - } + require.NotNil(t, status, "GetPRStatus returned nil status without error") + require.Equal(t, prURL, status.URL) } }) @@ -158,9 +150,7 @@ func TestGetPRStatus(t *testing.T) { prURL := "not-a-valid-url" _, err := client.GetPRStatus(ctx, prURL) - if err == nil { - t.Error("Expected error for invalid PR URL") - } + require.Error(t, err, "Expected error for invalid PR URL") }) } @@ -177,12 +167,8 @@ func TestFetchPRInfo(t *testing.T) { if err == nil { // Basic validation of the status fields - if status.State == "" { - t.Error("PR state should not be empty") - } - if status.MergeableState == "" { - t.Error("PR mergeable state should not be empty") - } + require.NotEmpty(t, status.State, "PR state should not be empty") + require.NotEmpty(t, status.MergeableState, "PR mergeable state should not be empty") } }) } @@ -200,9 +186,7 @@ func TestFetchStatusChecks(t *testing.T) { if err == nil { // Status checks might be empty, which is valid - if status.StatusChecks == nil { - t.Error("StatusChecks should be initialized even if empty") - } + require.NotNil(t, status.StatusChecks, "StatusChecks should be initialized even if empty") } }) } @@ -220,9 +204,7 @@ func TestFetchComments(t *testing.T) { if err == nil { // Comments might be empty, which is valid - if status.Comments == nil { - t.Error("Comments should be initialized even if empty") - } + require.NotNil(t, status.Comments, "Comments should be initialized even if empty") } }) } @@ -240,15 +222,11 @@ func TestFetchReviews(t *testing.T) { if err == nil { // Reviews might be empty, which is valid - if status.Reviews == nil { - t.Error("Reviews should be initialized even if empty") - } + require.NotNil(t, status.Reviews, "Reviews should be initialized even if empty") // Check that review comments are fetched for each review for _, review := range status.Reviews { - if review.Comments == nil { - t.Error("Review comments should be initialized even if empty") - } + require.NotNil(t, review.Comments, "Review comments should be initialized even if empty") } } }) @@ -267,21 +245,15 @@ func TestFetchWorkflowRuns(t *testing.T) { if err == nil { // Workflows might be empty, which is valid - if status.Workflows == nil { - t.Error("Workflows should be initialized even if empty") - } + require.NotNil(t, status.Workflows, "Workflows should be initialized even if empty") // Check that jobs are fetched for each workflow for _, workflow := range status.Workflows { - if workflow.Jobs == nil { - t.Error("Workflow jobs should be initialized even if empty") - } + require.NotNil(t, workflow.Jobs, "Workflow jobs should be initialized even if empty") // Check that steps are fetched for each job for _, job := range workflow.Jobs { - if job.Steps == nil { - t.Error("Job steps should be initialized even if empty") - } + require.NotNil(t, job.Steps, "Job steps should be initialized even if empty") } } } @@ -302,18 +274,10 @@ func TestPRStatusStructure(t *testing.T) { Workflows: []WorkflowRun{}, } - if status.URL != "https://github.com/owner/repo/pull/123" { - t.Error("PRStatus URL not set correctly") - } - if status.State != "OPEN" { - t.Error("PRStatus State not set correctly") - } - if !status.Mergeable { - t.Error("PRStatus Mergeable not set correctly") - } - if status.MergeableState != "clean" { - t.Error("PRStatus MergeableState not set correctly") - } + require.Equal(t, "https://github.com/owner/repo/pull/123", status.URL, "PRStatus URL not set correctly") + require.Equal(t, "OPEN", status.State, "PRStatus State not set correctly") + require.True(t, status.Mergeable, "PRStatus Mergeable not set correctly") + require.Equal(t, "clean", status.MergeableState, "PRStatus MergeableState not set correctly") } func TestStatusCheckStructure(t *testing.T) { @@ -324,18 +288,10 @@ func TestStatusCheckStructure(t *testing.T) { TargetURL: "https://travis-ci.org/owner/repo/builds/123", } - if check.Context != "continuous-integration/travis-ci" { - t.Error("StatusCheck Context not set correctly") - } - if check.State != "SUCCESS" { - t.Error("StatusCheck State not set correctly") - } - if check.Description != "The Travis CI build passed" { - t.Error("StatusCheck Description not set correctly") - } - if check.TargetURL != "https://travis-ci.org/owner/repo/builds/123" { - t.Error("StatusCheck TargetURL not set correctly") - } + require.Equal(t, "continuous-integration/travis-ci", check.Context, "StatusCheck Context not set correctly") + require.Equal(t, "SUCCESS", check.State, "StatusCheck State not set correctly") + require.Equal(t, "The Travis CI build passed", check.Description, "StatusCheck Description not set correctly") + require.Equal(t, "https://travis-ci.org/owner/repo/builds/123", check.TargetURL, "StatusCheck TargetURL not set correctly") } func TestWorkflowRunStructure(t *testing.T) { @@ -363,21 +319,11 @@ func TestWorkflowRunStructure(t *testing.T) { }, } - if workflow.ID != 123456 { - t.Error("WorkflowRun ID not set correctly") - } - if workflow.Name != "CI Pipeline" { - t.Error("WorkflowRun Name not set correctly") - } - if len(workflow.Jobs) != 1 { - t.Error("WorkflowRun Jobs not set correctly") - } - if len(workflow.Jobs[0].Steps) != 1 { - t.Error("Job Steps not set correctly") - } - if workflow.Jobs[0].Steps[0].Number != 3 { - t.Error("Step Number not set correctly") - } + require.Equal(t, int64(123456), workflow.ID, "WorkflowRun ID not set correctly") + require.Equal(t, "CI Pipeline", workflow.Name, "WorkflowRun Name not set correctly") + require.Len(t, workflow.Jobs, 1, "WorkflowRun Jobs not set correctly") + require.Len(t, workflow.Jobs[0].Steps, 1, "Job Steps not set correctly") + require.Equal(t, 3, workflow.Jobs[0].Steps[0].Number, "Step Number not set correctly") } func TestReviewStructure(t *testing.T) { @@ -397,22 +343,10 @@ func TestReviewStructure(t *testing.T) { }, } - if review.ID != 999 { - t.Error("Review ID not set correctly") - } - if review.State != "APPROVED" { - t.Error("Review State not set correctly") - } - if review.Body != "LGTM!" { - t.Error("Review Body not set correctly") - } - if review.Author != "reviewer1" { - t.Error("Review Author not set correctly") - } - if len(review.Comments) != 1 { - t.Error("Review Comments not set correctly") - } - if review.Comments[0].Line != 42 { - t.Error("ReviewComment Line not set correctly") - } -} \ No newline at end of file + require.Equal(t, 999, review.ID, "Review ID not set correctly") + require.Equal(t, "APPROVED", review.State, "Review State not set correctly") + require.Equal(t, "LGTM!", review.Body, "Review Body not set correctly") + require.Equal(t, "reviewer1", review.Author, "Review Author not set correctly") + require.Len(t, review.Comments, 1, "Review Comments not set correctly") + require.Equal(t, 42, review.Comments[0].Line, "ReviewComment Line not set correctly") +} diff --git a/internal/github/github_mock.go b/internal/github/github_mock.go new file mode 100644 index 00000000..d4cfd0d0 --- /dev/null +++ b/internal/github/github_mock.go @@ -0,0 +1,453 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package github + +import ( + "context" + "sync" +) + +// Ensure, that GitHubClientMock does implement ClientInterface. +// If this is not the case, regenerate this file with moq. +var _ ClientInterface = &GitHubClientMock{} + +// GitHubClientMock is a mock implementation of ClientInterface. +// +// func TestSomethingThatUsesClientInterface(t *testing.T) { +// +// // make and configure a mocked ClientInterface +// mockedClientInterface := &GitHubClientMock{ +// GetJobLogsFunc: func(ctx context.Context, repo string, jobID int64) (string, error) { +// panic("mock out the GetJobLogs method") +// }, +// GetPRMetadataFunc: func(ctx context.Context, prURLOrNumber string, repo string) (*PRMetadata, error) { +// panic("mock out the GetPRMetadata method") +// }, +// GetPRStatusFunc: func(ctx context.Context, prURL string) (*PRStatus, error) { +// panic("mock out the GetPRStatus method") +// }, +// PostPRCommentFunc: func(ctx context.Context, prURL string, body string) error { +// panic("mock out the PostPRComment method") +// }, +// PostReplyToCommentFunc: func(ctx context.Context, prURL string, commentID int, body string) error { +// panic("mock out the PostReplyToComment method") +// }, +// PostReviewReplyFunc: func(ctx context.Context, prURL string, reviewCommentID int, body string) error { +// panic("mock out the PostReviewReply method") +// }, +// ResolveReviewThreadFunc: func(ctx context.Context, prURL string, commentID int) error { +// panic("mock out the ResolveReviewThread method") +// }, +// } +// +// // use mockedClientInterface in code that requires ClientInterface +// // and then make assertions. +// +// } +type GitHubClientMock struct { + // GetJobLogsFunc mocks the GetJobLogs method. + GetJobLogsFunc func(ctx context.Context, repo string, jobID int64) (string, error) + + // GetPRMetadataFunc mocks the GetPRMetadata method. + GetPRMetadataFunc func(ctx context.Context, prURLOrNumber string, repo string) (*PRMetadata, error) + + // GetPRStatusFunc mocks the GetPRStatus method. + GetPRStatusFunc func(ctx context.Context, prURL string) (*PRStatus, error) + + // PostPRCommentFunc mocks the PostPRComment method. + PostPRCommentFunc func(ctx context.Context, prURL string, body string) error + + // PostReplyToCommentFunc mocks the PostReplyToComment method. + PostReplyToCommentFunc func(ctx context.Context, prURL string, commentID int, body string) error + + // PostReviewReplyFunc mocks the PostReviewReply method. + PostReviewReplyFunc func(ctx context.Context, prURL string, reviewCommentID int, body string) error + + // ResolveReviewThreadFunc mocks the ResolveReviewThread method. + ResolveReviewThreadFunc func(ctx context.Context, prURL string, commentID int) error + + // calls tracks calls to the methods. + calls struct { + // GetJobLogs holds details about calls to the GetJobLogs method. + GetJobLogs []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Repo is the repo argument value. + Repo string + // JobID is the jobID argument value. + JobID int64 + } + // GetPRMetadata holds details about calls to the GetPRMetadata method. + GetPRMetadata []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // PrURLOrNumber is the prURLOrNumber argument value. + PrURLOrNumber string + // Repo is the repo argument value. + Repo string + } + // GetPRStatus holds details about calls to the GetPRStatus method. + GetPRStatus []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // PrURL is the prURL argument value. + PrURL string + } + // PostPRComment holds details about calls to the PostPRComment method. + PostPRComment []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // PrURL is the prURL argument value. + PrURL string + // Body is the body argument value. + Body string + } + // PostReplyToComment holds details about calls to the PostReplyToComment method. + PostReplyToComment []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // PrURL is the prURL argument value. + PrURL string + // CommentID is the commentID argument value. + CommentID int + // Body is the body argument value. + Body string + } + // PostReviewReply holds details about calls to the PostReviewReply method. + PostReviewReply []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // PrURL is the prURL argument value. + PrURL string + // ReviewCommentID is the reviewCommentID argument value. + ReviewCommentID int + // Body is the body argument value. + Body string + } + // ResolveReviewThread holds details about calls to the ResolveReviewThread method. + ResolveReviewThread []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // PrURL is the prURL argument value. + PrURL string + // CommentID is the commentID argument value. + CommentID int + } + } + lockGetJobLogs sync.RWMutex + lockGetPRMetadata sync.RWMutex + lockGetPRStatus sync.RWMutex + lockPostPRComment sync.RWMutex + lockPostReplyToComment sync.RWMutex + lockPostReviewReply sync.RWMutex + lockResolveReviewThread sync.RWMutex +} + +// GetJobLogs calls GetJobLogsFunc. +func (mock *GitHubClientMock) GetJobLogs(ctx context.Context, repo string, jobID int64) (string, error) { + callInfo := struct { + Ctx context.Context + Repo string + JobID int64 + }{ + Ctx: ctx, + Repo: repo, + JobID: jobID, + } + mock.lockGetJobLogs.Lock() + mock.calls.GetJobLogs = append(mock.calls.GetJobLogs, callInfo) + mock.lockGetJobLogs.Unlock() + if mock.GetJobLogsFunc == nil { + var ( + sOut string + errOut error + ) + return sOut, errOut + } + return mock.GetJobLogsFunc(ctx, repo, jobID) +} + +// GetJobLogsCalls gets all the calls that were made to GetJobLogs. +// Check the length with: +// +// len(mockedClientInterface.GetJobLogsCalls()) +func (mock *GitHubClientMock) GetJobLogsCalls() []struct { + Ctx context.Context + Repo string + JobID int64 +} { + var calls []struct { + Ctx context.Context + Repo string + JobID int64 + } + mock.lockGetJobLogs.RLock() + calls = mock.calls.GetJobLogs + mock.lockGetJobLogs.RUnlock() + return calls +} + +// GetPRMetadata calls GetPRMetadataFunc. +func (mock *GitHubClientMock) GetPRMetadata(ctx context.Context, prURLOrNumber string, repo string) (*PRMetadata, error) { + callInfo := struct { + Ctx context.Context + PrURLOrNumber string + Repo string + }{ + Ctx: ctx, + PrURLOrNumber: prURLOrNumber, + Repo: repo, + } + mock.lockGetPRMetadata.Lock() + mock.calls.GetPRMetadata = append(mock.calls.GetPRMetadata, callInfo) + mock.lockGetPRMetadata.Unlock() + if mock.GetPRMetadataFunc == nil { + var ( + pRMetadataOut *PRMetadata + errOut error + ) + return pRMetadataOut, errOut + } + return mock.GetPRMetadataFunc(ctx, prURLOrNumber, repo) +} + +// GetPRMetadataCalls gets all the calls that were made to GetPRMetadata. +// Check the length with: +// +// len(mockedClientInterface.GetPRMetadataCalls()) +func (mock *GitHubClientMock) GetPRMetadataCalls() []struct { + Ctx context.Context + PrURLOrNumber string + Repo string +} { + var calls []struct { + Ctx context.Context + PrURLOrNumber string + Repo string + } + mock.lockGetPRMetadata.RLock() + calls = mock.calls.GetPRMetadata + mock.lockGetPRMetadata.RUnlock() + return calls +} + +// GetPRStatus calls GetPRStatusFunc. +func (mock *GitHubClientMock) GetPRStatus(ctx context.Context, prURL string) (*PRStatus, error) { + callInfo := struct { + Ctx context.Context + PrURL string + }{ + Ctx: ctx, + PrURL: prURL, + } + mock.lockGetPRStatus.Lock() + mock.calls.GetPRStatus = append(mock.calls.GetPRStatus, callInfo) + mock.lockGetPRStatus.Unlock() + if mock.GetPRStatusFunc == nil { + var ( + pRStatusOut *PRStatus + errOut error + ) + return pRStatusOut, errOut + } + return mock.GetPRStatusFunc(ctx, prURL) +} + +// GetPRStatusCalls gets all the calls that were made to GetPRStatus. +// Check the length with: +// +// len(mockedClientInterface.GetPRStatusCalls()) +func (mock *GitHubClientMock) GetPRStatusCalls() []struct { + Ctx context.Context + PrURL string +} { + var calls []struct { + Ctx context.Context + PrURL string + } + mock.lockGetPRStatus.RLock() + calls = mock.calls.GetPRStatus + mock.lockGetPRStatus.RUnlock() + return calls +} + +// PostPRComment calls PostPRCommentFunc. +func (mock *GitHubClientMock) PostPRComment(ctx context.Context, prURL string, body string) error { + callInfo := struct { + Ctx context.Context + PrURL string + Body string + }{ + Ctx: ctx, + PrURL: prURL, + Body: body, + } + mock.lockPostPRComment.Lock() + mock.calls.PostPRComment = append(mock.calls.PostPRComment, callInfo) + mock.lockPostPRComment.Unlock() + if mock.PostPRCommentFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.PostPRCommentFunc(ctx, prURL, body) +} + +// PostPRCommentCalls gets all the calls that were made to PostPRComment. +// Check the length with: +// +// len(mockedClientInterface.PostPRCommentCalls()) +func (mock *GitHubClientMock) PostPRCommentCalls() []struct { + Ctx context.Context + PrURL string + Body string +} { + var calls []struct { + Ctx context.Context + PrURL string + Body string + } + mock.lockPostPRComment.RLock() + calls = mock.calls.PostPRComment + mock.lockPostPRComment.RUnlock() + return calls +} + +// PostReplyToComment calls PostReplyToCommentFunc. +func (mock *GitHubClientMock) PostReplyToComment(ctx context.Context, prURL string, commentID int, body string) error { + callInfo := struct { + Ctx context.Context + PrURL string + CommentID int + Body string + }{ + Ctx: ctx, + PrURL: prURL, + CommentID: commentID, + Body: body, + } + mock.lockPostReplyToComment.Lock() + mock.calls.PostReplyToComment = append(mock.calls.PostReplyToComment, callInfo) + mock.lockPostReplyToComment.Unlock() + if mock.PostReplyToCommentFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.PostReplyToCommentFunc(ctx, prURL, commentID, body) +} + +// PostReplyToCommentCalls gets all the calls that were made to PostReplyToComment. +// Check the length with: +// +// len(mockedClientInterface.PostReplyToCommentCalls()) +func (mock *GitHubClientMock) PostReplyToCommentCalls() []struct { + Ctx context.Context + PrURL string + CommentID int + Body string +} { + var calls []struct { + Ctx context.Context + PrURL string + CommentID int + Body string + } + mock.lockPostReplyToComment.RLock() + calls = mock.calls.PostReplyToComment + mock.lockPostReplyToComment.RUnlock() + return calls +} + +// PostReviewReply calls PostReviewReplyFunc. +func (mock *GitHubClientMock) PostReviewReply(ctx context.Context, prURL string, reviewCommentID int, body string) error { + callInfo := struct { + Ctx context.Context + PrURL string + ReviewCommentID int + Body string + }{ + Ctx: ctx, + PrURL: prURL, + ReviewCommentID: reviewCommentID, + Body: body, + } + mock.lockPostReviewReply.Lock() + mock.calls.PostReviewReply = append(mock.calls.PostReviewReply, callInfo) + mock.lockPostReviewReply.Unlock() + if mock.PostReviewReplyFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.PostReviewReplyFunc(ctx, prURL, reviewCommentID, body) +} + +// PostReviewReplyCalls gets all the calls that were made to PostReviewReply. +// Check the length with: +// +// len(mockedClientInterface.PostReviewReplyCalls()) +func (mock *GitHubClientMock) PostReviewReplyCalls() []struct { + Ctx context.Context + PrURL string + ReviewCommentID int + Body string +} { + var calls []struct { + Ctx context.Context + PrURL string + ReviewCommentID int + Body string + } + mock.lockPostReviewReply.RLock() + calls = mock.calls.PostReviewReply + mock.lockPostReviewReply.RUnlock() + return calls +} + +// ResolveReviewThread calls ResolveReviewThreadFunc. +func (mock *GitHubClientMock) ResolveReviewThread(ctx context.Context, prURL string, commentID int) error { + callInfo := struct { + Ctx context.Context + PrURL string + CommentID int + }{ + Ctx: ctx, + PrURL: prURL, + CommentID: commentID, + } + mock.lockResolveReviewThread.Lock() + mock.calls.ResolveReviewThread = append(mock.calls.ResolveReviewThread, callInfo) + mock.lockResolveReviewThread.Unlock() + if mock.ResolveReviewThreadFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.ResolveReviewThreadFunc(ctx, prURL, commentID) +} + +// ResolveReviewThreadCalls gets all the calls that were made to ResolveReviewThread. +// Check the length with: +// +// len(mockedClientInterface.ResolveReviewThreadCalls()) +func (mock *GitHubClientMock) ResolveReviewThreadCalls() []struct { + Ctx context.Context + PrURL string + CommentID int +} { + var calls []struct { + Ctx context.Context + PrURL string + CommentID int + } + mock.lockResolveReviewThread.RLock() + calls = mock.calls.ResolveReviewThread + mock.lockResolveReviewThread.RUnlock() + return calls +} diff --git a/internal/linear/client.go b/internal/linear/client.go index 0beab3a2..09090b5e 100644 --- a/internal/linear/client.go +++ b/internal/linear/client.go @@ -1,5 +1,7 @@ package linear +//go:generate moq -stub -out linear_mock.go . ClientInterface:LinearClientMock + import ( "bytes" "context" diff --git a/internal/linear/client_test.go b/internal/linear/client_test.go index 1f39636f..202b2b21 100644 --- a/internal/linear/client_test.go +++ b/internal/linear/client_test.go @@ -2,6 +2,8 @@ package linear import ( "testing" + + "github.com/stretchr/testify/require" ) func TestParseIssueIDOrURL(t *testing.T) { @@ -66,13 +68,14 @@ func TestParseIssueIDOrURL(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := ParseIssueIDOrURL(tt.input) - if (err != nil) != tt.wantErr { - t.Errorf("ParseIssueIDOrURL() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("ParseIssueIDOrURL() = %v, want %v", got, tt.want) + + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) } + + require.Equal(t, tt.want, got) }) } } @@ -98,12 +101,12 @@ func TestNewClient(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { client, err := NewClient(tt.apiKey) - if (err != nil) != tt.wantErr { - t.Errorf("NewClient() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !tt.wantErr && client == nil { - t.Error("NewClient() returned nil client without error") + + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.NotNil(t, client, "NewClient() returned nil client without error") } }) } diff --git a/internal/linear/fetcher_test.go b/internal/linear/fetcher_test.go index 8b01c41b..5a9a0dc7 100644 --- a/internal/linear/fetcher_test.go +++ b/internal/linear/fetcher_test.go @@ -2,25 +2,12 @@ package linear import ( "context" - "errors" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -// MockClient is a mock implementation of the Linear client for testing -type MockClient struct { - GetIssueFunc func(ctx context.Context, issueID string) (*Issue, error) -} - -func (m *MockClient) GetIssue(ctx context.Context, issueID string) (*Issue, error) { - if m.GetIssueFunc != nil { - return m.GetIssueFunc(ctx, issueID) - } - return nil, errors.New("not implemented") -} - func TestFetcherErrorHandling(t *testing.T) { ctx := context.Background() diff --git a/internal/linear/integration_test.go b/internal/linear/integration_test.go index 1ccf764f..667e8780 100644 --- a/internal/linear/integration_test.go +++ b/internal/linear/integration_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/newhook/co/internal/linear" + "github.com/stretchr/testify/require" ) // TestLinearImportIntegration demonstrates the complete Linear import workflow @@ -17,23 +18,16 @@ func TestLinearImportIntegration(t *testing.T) { // Initialize the fetcher with API key and beads directory apiKey := "lin_api_test_key" // In production, get from env or config - beadsDir := "/path/to/beads" // In production, auto-detect or get from config + beadsDir := "/path/to/beads" // In production, auto-detect or get from config fetcher, err := linear.NewFetcher(apiKey, beadsDir) - if err != nil { - t.Fatalf("Failed to create fetcher: %v", err) - } + require.NoError(t, err, "Failed to create fetcher") // Example 1: Simple import of a single issue t.Run("ImportSingleIssue", func(t *testing.T) { result, err := fetcher.FetchAndImport(ctx, "ENG-123", nil) - if err != nil { - t.Fatalf("Failed to import issue: %v", err) - } - - if !result.Success { - t.Fatalf("Import failed: %s", result.SkipReason) - } + require.NoError(t, err, "Failed to import issue") + require.True(t, result.Success, "Import failed: %s", result.SkipReason) t.Logf("Imported Linear issue %s as bead %s", result.LinearID, result.BeadID) }) @@ -48,9 +42,7 @@ func TestLinearImportIntegration(t *testing.T) { } result, err := fetcher.FetchAndImport(ctx, "ENG-456", opts) - if err != nil { - t.Fatalf("Failed to import issue: %v", err) - } + require.NoError(t, err, "Failed to import issue") if result.SkipReason == "already imported" { t.Logf("Issue already imported as bead %s", result.BeadID) @@ -63,9 +55,7 @@ func TestLinearImportIntegration(t *testing.T) { t.Run("ImportByURL", func(t *testing.T) { url := "https://linear.app/company/issue/ENG-789/feature-title" result, err := fetcher.FetchAndImport(ctx, url, nil) - if err != nil { - t.Fatalf("Failed to import issue: %v", err) - } + require.NoError(t, err, "Failed to import issue") if result.Success { t.Logf("Imported Linear issue from URL: %s -> bead %s", result.LinearURL, result.BeadID) @@ -82,9 +72,7 @@ func TestLinearImportIntegration(t *testing.T) { } results, err := fetcher.FetchBatch(ctx, issues, opts) - if err != nil { - t.Fatalf("Batch import failed: %v", err) - } + require.NoError(t, err, "Batch import failed") successCount := 0 for _, result := range results { @@ -108,9 +96,7 @@ func TestLinearImportIntegration(t *testing.T) { } result, err := fetcher.FetchAndImport(ctx, "ENG-123", opts) - if err != nil { - t.Fatalf("Failed to update issue: %v", err) - } + require.NoError(t, err, "Failed to update issue") if result.SkipReason == "updated existing bead" { t.Logf("Updated existing bead %s with latest data from Linear", result.BeadID) @@ -128,9 +114,7 @@ func TestLinearImportIntegration(t *testing.T) { } result, err := fetcher.FetchAndImport(ctx, "ENG-999", opts) - if err != nil { - t.Fatalf("Failed to import: %v", err) - } + require.NoError(t, err, "Failed to import") if result.SkipReason != "" { t.Logf("Skipped: %s", result.SkipReason) @@ -146,9 +130,7 @@ func TestLinearImportIntegration(t *testing.T) { } result, err := fetcher.FetchAndImport(ctx, "ENG-777", opts) - if err != nil { - t.Fatalf("Failed dry run: %v", err) - } + require.NoError(t, err, "Failed dry run") if result.SkipReason == "dry run" { t.Logf("Dry run successful - would import Linear issue %s", result.LinearID) @@ -165,9 +147,7 @@ func TestLinearImportErrorHandling(t *testing.T) { // Example: Invalid API key t.Run("InvalidAPIKey", func(t *testing.T) { _, err := linear.NewFetcher("invalid_key", "/path/to/beads") - if err == nil { - t.Fatal("Expected error for invalid API key") - } + require.Error(t, err, "Expected error for invalid API key") t.Logf("Got expected error: %v", err) }) @@ -175,9 +155,7 @@ func TestLinearImportErrorHandling(t *testing.T) { t.Run("InvalidIssueID", func(t *testing.T) { fetcher, _ := linear.NewFetcher("valid_key", "/path/to/beads") result, err := fetcher.FetchAndImport(ctx, "INVALID-999", nil) - if err == nil && result.Success { - t.Fatal("Expected error for invalid issue ID") - } + require.True(t, err != nil || !result.Success, "Expected error for invalid issue ID") if result.Error != nil { t.Logf("Got expected error: %v", result.Error) } diff --git a/internal/linear/linear_mock.go b/internal/linear/linear_mock.go new file mode 100644 index 00000000..720d3b8e --- /dev/null +++ b/internal/linear/linear_mock.go @@ -0,0 +1,253 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package linear + +import ( + "context" + "sync" +) + +// Ensure, that LinearClientMock does implement ClientInterface. +// If this is not the case, regenerate this file with moq. +var _ ClientInterface = &LinearClientMock{} + +// LinearClientMock is a mock implementation of ClientInterface. +// +// func TestSomethingThatUsesClientInterface(t *testing.T) { +// +// // make and configure a mocked ClientInterface +// mockedClientInterface := &LinearClientMock{ +// GetIssueFunc: func(ctx context.Context, issueIDOrURL string) (*Issue, error) { +// panic("mock out the GetIssue method") +// }, +// GetIssueCommentsFunc: func(ctx context.Context, issueID string) ([]Comment, error) { +// panic("mock out the GetIssueComments method") +// }, +// ListIssuesFunc: func(ctx context.Context, filters map[string]any) ([]*Issue, error) { +// panic("mock out the ListIssues method") +// }, +// SearchIssuesFunc: func(ctx context.Context, searchQuery string, filters map[string]any) ([]*Issue, error) { +// panic("mock out the SearchIssues method") +// }, +// } +// +// // use mockedClientInterface in code that requires ClientInterface +// // and then make assertions. +// +// } +type LinearClientMock struct { + // GetIssueFunc mocks the GetIssue method. + GetIssueFunc func(ctx context.Context, issueIDOrURL string) (*Issue, error) + + // GetIssueCommentsFunc mocks the GetIssueComments method. + GetIssueCommentsFunc func(ctx context.Context, issueID string) ([]Comment, error) + + // ListIssuesFunc mocks the ListIssues method. + ListIssuesFunc func(ctx context.Context, filters map[string]any) ([]*Issue, error) + + // SearchIssuesFunc mocks the SearchIssues method. + SearchIssuesFunc func(ctx context.Context, searchQuery string, filters map[string]any) ([]*Issue, error) + + // calls tracks calls to the methods. + calls struct { + // GetIssue holds details about calls to the GetIssue method. + GetIssue []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // IssueIDOrURL is the issueIDOrURL argument value. + IssueIDOrURL string + } + // GetIssueComments holds details about calls to the GetIssueComments method. + GetIssueComments []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // IssueID is the issueID argument value. + IssueID string + } + // ListIssues holds details about calls to the ListIssues method. + ListIssues []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Filters is the filters argument value. + Filters map[string]any + } + // SearchIssues holds details about calls to the SearchIssues method. + SearchIssues []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // SearchQuery is the searchQuery argument value. + SearchQuery string + // Filters is the filters argument value. + Filters map[string]any + } + } + lockGetIssue sync.RWMutex + lockGetIssueComments sync.RWMutex + lockListIssues sync.RWMutex + lockSearchIssues sync.RWMutex +} + +// GetIssue calls GetIssueFunc. +func (mock *LinearClientMock) GetIssue(ctx context.Context, issueIDOrURL string) (*Issue, error) { + callInfo := struct { + Ctx context.Context + IssueIDOrURL string + }{ + Ctx: ctx, + IssueIDOrURL: issueIDOrURL, + } + mock.lockGetIssue.Lock() + mock.calls.GetIssue = append(mock.calls.GetIssue, callInfo) + mock.lockGetIssue.Unlock() + if mock.GetIssueFunc == nil { + var ( + issueOut *Issue + errOut error + ) + return issueOut, errOut + } + return mock.GetIssueFunc(ctx, issueIDOrURL) +} + +// GetIssueCalls gets all the calls that were made to GetIssue. +// Check the length with: +// +// len(mockedClientInterface.GetIssueCalls()) +func (mock *LinearClientMock) GetIssueCalls() []struct { + Ctx context.Context + IssueIDOrURL string +} { + var calls []struct { + Ctx context.Context + IssueIDOrURL string + } + mock.lockGetIssue.RLock() + calls = mock.calls.GetIssue + mock.lockGetIssue.RUnlock() + return calls +} + +// GetIssueComments calls GetIssueCommentsFunc. +func (mock *LinearClientMock) GetIssueComments(ctx context.Context, issueID string) ([]Comment, error) { + callInfo := struct { + Ctx context.Context + IssueID string + }{ + Ctx: ctx, + IssueID: issueID, + } + mock.lockGetIssueComments.Lock() + mock.calls.GetIssueComments = append(mock.calls.GetIssueComments, callInfo) + mock.lockGetIssueComments.Unlock() + if mock.GetIssueCommentsFunc == nil { + var ( + commentsOut []Comment + errOut error + ) + return commentsOut, errOut + } + return mock.GetIssueCommentsFunc(ctx, issueID) +} + +// GetIssueCommentsCalls gets all the calls that were made to GetIssueComments. +// Check the length with: +// +// len(mockedClientInterface.GetIssueCommentsCalls()) +func (mock *LinearClientMock) GetIssueCommentsCalls() []struct { + Ctx context.Context + IssueID string +} { + var calls []struct { + Ctx context.Context + IssueID string + } + mock.lockGetIssueComments.RLock() + calls = mock.calls.GetIssueComments + mock.lockGetIssueComments.RUnlock() + return calls +} + +// ListIssues calls ListIssuesFunc. +func (mock *LinearClientMock) ListIssues(ctx context.Context, filters map[string]any) ([]*Issue, error) { + callInfo := struct { + Ctx context.Context + Filters map[string]any + }{ + Ctx: ctx, + Filters: filters, + } + mock.lockListIssues.Lock() + mock.calls.ListIssues = append(mock.calls.ListIssues, callInfo) + mock.lockListIssues.Unlock() + if mock.ListIssuesFunc == nil { + var ( + issuesOut []*Issue + errOut error + ) + return issuesOut, errOut + } + return mock.ListIssuesFunc(ctx, filters) +} + +// ListIssuesCalls gets all the calls that were made to ListIssues. +// Check the length with: +// +// len(mockedClientInterface.ListIssuesCalls()) +func (mock *LinearClientMock) ListIssuesCalls() []struct { + Ctx context.Context + Filters map[string]any +} { + var calls []struct { + Ctx context.Context + Filters map[string]any + } + mock.lockListIssues.RLock() + calls = mock.calls.ListIssues + mock.lockListIssues.RUnlock() + return calls +} + +// SearchIssues calls SearchIssuesFunc. +func (mock *LinearClientMock) SearchIssues(ctx context.Context, searchQuery string, filters map[string]any) ([]*Issue, error) { + callInfo := struct { + Ctx context.Context + SearchQuery string + Filters map[string]any + }{ + Ctx: ctx, + SearchQuery: searchQuery, + Filters: filters, + } + mock.lockSearchIssues.Lock() + mock.calls.SearchIssues = append(mock.calls.SearchIssues, callInfo) + mock.lockSearchIssues.Unlock() + if mock.SearchIssuesFunc == nil { + var ( + issuesOut []*Issue + errOut error + ) + return issuesOut, errOut + } + return mock.SearchIssuesFunc(ctx, searchQuery, filters) +} + +// SearchIssuesCalls gets all the calls that were made to SearchIssues. +// Check the length with: +// +// len(mockedClientInterface.SearchIssuesCalls()) +func (mock *LinearClientMock) SearchIssuesCalls() []struct { + Ctx context.Context + SearchQuery string + Filters map[string]any +} { + var calls []struct { + Ctx context.Context + SearchQuery string + Filters map[string]any + } + mock.lockSearchIssues.RLock() + calls = mock.calls.SearchIssues + mock.lockSearchIssues.RUnlock() + return calls +} diff --git a/internal/linear/mapper_test.go b/internal/linear/mapper_test.go index 3294ec72..902a7fd6 100644 --- a/internal/linear/mapper_test.go +++ b/internal/linear/mapper_test.go @@ -1,7 +1,10 @@ package linear import ( + "strings" "testing" + + "github.com/stretchr/testify/require" ) func TestMapStatus(t *testing.T) { @@ -45,9 +48,7 @@ func TestMapStatus(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := MapStatus(tt.state) - if got != tt.want { - t.Errorf("MapStatus() = %v, want %v", got, tt.want) - } + require.Equal(t, tt.want, got) }) } } @@ -93,9 +94,7 @@ func TestMapPriority(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := MapPriority(tt.priority) - if got != tt.want { - t.Errorf("MapPriority() = %v, want %v", got, tt.want) - } + require.Equal(t, tt.want, got) }) } } @@ -164,9 +163,7 @@ func TestMapType(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := MapType(tt.issue) - if got != tt.want { - t.Errorf("MapType() = %v, want %v", got, tt.want) - } + require.Equal(t, tt.want, got) }) } } @@ -188,36 +185,16 @@ func TestMapIssueToBeadCreate(t *testing.T) { opts := MapIssueToBeadCreate(issue) - if opts.Title != "Fix authentication bug" { - t.Errorf("Title = %v, want %v", opts.Title, "Fix authentication bug") - } - if opts.Type != "bug" { - t.Errorf("Type = %v, want %v", opts.Type, "bug") - } - if opts.Priority != "P1" { - t.Errorf("Priority = %v, want %v", opts.Priority, "P1") - } - if opts.Status != "in_progress" { - t.Errorf("Status = %v, want %v", opts.Status, "in_progress") - } - if opts.Assignee != "john@example.com" { - t.Errorf("Assignee = %v, want %v", opts.Assignee, "john@example.com") - } - if len(opts.Labels) != 2 { - t.Errorf("Labels count = %v, want %v", len(opts.Labels), 2) - } - if opts.Metadata["linear_id"] != "ENG-123" { - t.Errorf("Metadata linear_id = %v, want %v", opts.Metadata["linear_id"], "ENG-123") - } - if opts.Metadata["linear_url"] != "https://linear.app/team/issue/ENG-123" { - t.Errorf("Metadata linear_url = %v, want %v", opts.Metadata["linear_url"], "https://linear.app/team/issue/ENG-123") - } - if opts.Metadata["linear_project"] != "Q1 Features" { - t.Errorf("Metadata linear_project = %v, want %v", opts.Metadata["linear_project"], "Q1 Features") - } - if opts.Metadata["linear_estimate"] != "3.5" { - t.Errorf("Metadata linear_estimate = %v, want %v", opts.Metadata["linear_estimate"], "3.5") - } + require.Equal(t, "Fix authentication bug", opts.Title) + require.Equal(t, "bug", opts.Type) + require.Equal(t, "P1", opts.Priority) + require.Equal(t, "in_progress", opts.Status) + require.Equal(t, "john@example.com", opts.Assignee) + require.Len(t, opts.Labels, 2) + require.Equal(t, "ENG-123", opts.Metadata["linear_id"]) + require.Equal(t, "https://linear.app/team/issue/ENG-123", opts.Metadata["linear_url"]) + require.Equal(t, "Q1 Features", opts.Metadata["linear_project"]) + require.Equal(t, "3.5", opts.Metadata["linear_estimate"]) } func TestFormatBeadDescription(t *testing.T) { @@ -236,38 +213,11 @@ func TestFormatBeadDescription(t *testing.T) { desc := FormatBeadDescription(issue) // Check that it contains key elements - if !contains(desc, "Original description") { - t.Error("Description should contain original description") - } - if !contains(desc, "ENG-456") { - t.Error("Description should contain Linear ID") - } - if !contains(desc, "https://linear.app/team/issue/ENG-456") { - t.Error("Description should contain URL") - } - if !contains(desc, "In Progress") { - t.Error("Description should contain state name") - } - if !contains(desc, "Backend") { - t.Error("Description should contain project name") - } - if !contains(desc, "2.0") { - t.Error("Description should contain estimate") - } - if !contains(desc, "Jane Smith") { - t.Error("Description should contain assignee name") - } -} - -func contains(s, substr string) bool { - return len(s) > 0 && len(substr) > 0 && s != substr && (s == substr || len(s) > len(substr) && (s[:len(substr)] == substr || s[len(s)-len(substr):] == substr || containsInMiddle(s, substr))) -} - -func containsInMiddle(s, substr string) bool { - for i := 0; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return true - } - } - return false + require.True(t, strings.Contains(desc, "Original description"), "Description should contain original description") + require.True(t, strings.Contains(desc, "ENG-456"), "Description should contain Linear ID") + require.True(t, strings.Contains(desc, "https://linear.app/team/issue/ENG-456"), "Description should contain URL") + require.True(t, strings.Contains(desc, "In Progress"), "Description should contain state name") + require.True(t, strings.Contains(desc, "Backend"), "Description should contain project name") + require.True(t, strings.Contains(desc, "2.0"), "Description should contain estimate") + require.True(t, strings.Contains(desc, "Jane Smith"), "Description should contain assignee name") } diff --git a/internal/mise/mise.go b/internal/mise/mise.go index 6141a04c..1b22fe66 100644 --- a/internal/mise/mise.go +++ b/internal/mise/mise.go @@ -1,5 +1,7 @@ package mise +//go:generate moq -stub -out mise_mock.go . Operations:MiseOperationsMock + import ( "fmt" "io" diff --git a/internal/mise/mise_mock.go b/internal/mise/mise_mock.go new file mode 100644 index 00000000..c7b2a696 --- /dev/null +++ b/internal/mise/mise_mock.go @@ -0,0 +1,386 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package mise + +import ( + "io" + "sync" +) + +// Ensure, that MiseOperationsMock does implement Operations. +// If this is not the case, regenerate this file with moq. +var _ Operations = &MiseOperationsMock{} + +// MiseOperationsMock is a mock implementation of Operations. +// +// func TestSomethingThatUsesOperations(t *testing.T) { +// +// // make and configure a mocked Operations +// mockedOperations := &MiseOperationsMock{ +// ExecFunc: func(command string, args ...string) ([]byte, error) { +// panic("mock out the Exec method") +// }, +// HasTaskFunc: func(taskName string) bool { +// panic("mock out the HasTask method") +// }, +// InitializeFunc: func() error { +// panic("mock out the Initialize method") +// }, +// InitializeWithOutputFunc: func(w io.Writer) error { +// panic("mock out the InitializeWithOutput method") +// }, +// InstallFunc: func() error { +// panic("mock out the Install method") +// }, +// IsManagedFunc: func() bool { +// panic("mock out the IsManaged method") +// }, +// RunTaskFunc: func(taskName string) error { +// panic("mock out the RunTask method") +// }, +// TrustFunc: func() error { +// panic("mock out the Trust method") +// }, +// } +// +// // use mockedOperations in code that requires Operations +// // and then make assertions. +// +// } +type MiseOperationsMock struct { + // ExecFunc mocks the Exec method. + ExecFunc func(command string, args ...string) ([]byte, error) + + // HasTaskFunc mocks the HasTask method. + HasTaskFunc func(taskName string) bool + + // InitializeFunc mocks the Initialize method. + InitializeFunc func() error + + // InitializeWithOutputFunc mocks the InitializeWithOutput method. + InitializeWithOutputFunc func(w io.Writer) error + + // InstallFunc mocks the Install method. + InstallFunc func() error + + // IsManagedFunc mocks the IsManaged method. + IsManagedFunc func() bool + + // RunTaskFunc mocks the RunTask method. + RunTaskFunc func(taskName string) error + + // TrustFunc mocks the Trust method. + TrustFunc func() error + + // calls tracks calls to the methods. + calls struct { + // Exec holds details about calls to the Exec method. + Exec []struct { + // Command is the command argument value. + Command string + // Args is the args argument value. + Args []string + } + // HasTask holds details about calls to the HasTask method. + HasTask []struct { + // TaskName is the taskName argument value. + TaskName string + } + // Initialize holds details about calls to the Initialize method. + Initialize []struct { + } + // InitializeWithOutput holds details about calls to the InitializeWithOutput method. + InitializeWithOutput []struct { + // W is the w argument value. + W io.Writer + } + // Install holds details about calls to the Install method. + Install []struct { + } + // IsManaged holds details about calls to the IsManaged method. + IsManaged []struct { + } + // RunTask holds details about calls to the RunTask method. + RunTask []struct { + // TaskName is the taskName argument value. + TaskName string + } + // Trust holds details about calls to the Trust method. + Trust []struct { + } + } + lockExec sync.RWMutex + lockHasTask sync.RWMutex + lockInitialize sync.RWMutex + lockInitializeWithOutput sync.RWMutex + lockInstall sync.RWMutex + lockIsManaged sync.RWMutex + lockRunTask sync.RWMutex + lockTrust sync.RWMutex +} + +// Exec calls ExecFunc. +func (mock *MiseOperationsMock) Exec(command string, args ...string) ([]byte, error) { + callInfo := struct { + Command string + Args []string + }{ + Command: command, + Args: args, + } + mock.lockExec.Lock() + mock.calls.Exec = append(mock.calls.Exec, callInfo) + mock.lockExec.Unlock() + if mock.ExecFunc == nil { + var ( + bytesOut []byte + errOut error + ) + return bytesOut, errOut + } + return mock.ExecFunc(command, args...) +} + +// ExecCalls gets all the calls that were made to Exec. +// Check the length with: +// +// len(mockedOperations.ExecCalls()) +func (mock *MiseOperationsMock) ExecCalls() []struct { + Command string + Args []string +} { + var calls []struct { + Command string + Args []string + } + mock.lockExec.RLock() + calls = mock.calls.Exec + mock.lockExec.RUnlock() + return calls +} + +// HasTask calls HasTaskFunc. +func (mock *MiseOperationsMock) HasTask(taskName string) bool { + callInfo := struct { + TaskName string + }{ + TaskName: taskName, + } + mock.lockHasTask.Lock() + mock.calls.HasTask = append(mock.calls.HasTask, callInfo) + mock.lockHasTask.Unlock() + if mock.HasTaskFunc == nil { + var ( + bOut bool + ) + return bOut + } + return mock.HasTaskFunc(taskName) +} + +// HasTaskCalls gets all the calls that were made to HasTask. +// Check the length with: +// +// len(mockedOperations.HasTaskCalls()) +func (mock *MiseOperationsMock) HasTaskCalls() []struct { + TaskName string +} { + var calls []struct { + TaskName string + } + mock.lockHasTask.RLock() + calls = mock.calls.HasTask + mock.lockHasTask.RUnlock() + return calls +} + +// Initialize calls InitializeFunc. +func (mock *MiseOperationsMock) Initialize() error { + callInfo := struct { + }{} + mock.lockInitialize.Lock() + mock.calls.Initialize = append(mock.calls.Initialize, callInfo) + mock.lockInitialize.Unlock() + if mock.InitializeFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.InitializeFunc() +} + +// InitializeCalls gets all the calls that were made to Initialize. +// Check the length with: +// +// len(mockedOperations.InitializeCalls()) +func (mock *MiseOperationsMock) InitializeCalls() []struct { +} { + var calls []struct { + } + mock.lockInitialize.RLock() + calls = mock.calls.Initialize + mock.lockInitialize.RUnlock() + return calls +} + +// InitializeWithOutput calls InitializeWithOutputFunc. +func (mock *MiseOperationsMock) InitializeWithOutput(w io.Writer) error { + callInfo := struct { + W io.Writer + }{ + W: w, + } + mock.lockInitializeWithOutput.Lock() + mock.calls.InitializeWithOutput = append(mock.calls.InitializeWithOutput, callInfo) + mock.lockInitializeWithOutput.Unlock() + if mock.InitializeWithOutputFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.InitializeWithOutputFunc(w) +} + +// InitializeWithOutputCalls gets all the calls that were made to InitializeWithOutput. +// Check the length with: +// +// len(mockedOperations.InitializeWithOutputCalls()) +func (mock *MiseOperationsMock) InitializeWithOutputCalls() []struct { + W io.Writer +} { + var calls []struct { + W io.Writer + } + mock.lockInitializeWithOutput.RLock() + calls = mock.calls.InitializeWithOutput + mock.lockInitializeWithOutput.RUnlock() + return calls +} + +// Install calls InstallFunc. +func (mock *MiseOperationsMock) Install() error { + callInfo := struct { + }{} + mock.lockInstall.Lock() + mock.calls.Install = append(mock.calls.Install, callInfo) + mock.lockInstall.Unlock() + if mock.InstallFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.InstallFunc() +} + +// InstallCalls gets all the calls that were made to Install. +// Check the length with: +// +// len(mockedOperations.InstallCalls()) +func (mock *MiseOperationsMock) InstallCalls() []struct { +} { + var calls []struct { + } + mock.lockInstall.RLock() + calls = mock.calls.Install + mock.lockInstall.RUnlock() + return calls +} + +// IsManaged calls IsManagedFunc. +func (mock *MiseOperationsMock) IsManaged() bool { + callInfo := struct { + }{} + mock.lockIsManaged.Lock() + mock.calls.IsManaged = append(mock.calls.IsManaged, callInfo) + mock.lockIsManaged.Unlock() + if mock.IsManagedFunc == nil { + var ( + bOut bool + ) + return bOut + } + return mock.IsManagedFunc() +} + +// IsManagedCalls gets all the calls that were made to IsManaged. +// Check the length with: +// +// len(mockedOperations.IsManagedCalls()) +func (mock *MiseOperationsMock) IsManagedCalls() []struct { +} { + var calls []struct { + } + mock.lockIsManaged.RLock() + calls = mock.calls.IsManaged + mock.lockIsManaged.RUnlock() + return calls +} + +// RunTask calls RunTaskFunc. +func (mock *MiseOperationsMock) RunTask(taskName string) error { + callInfo := struct { + TaskName string + }{ + TaskName: taskName, + } + mock.lockRunTask.Lock() + mock.calls.RunTask = append(mock.calls.RunTask, callInfo) + mock.lockRunTask.Unlock() + if mock.RunTaskFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.RunTaskFunc(taskName) +} + +// RunTaskCalls gets all the calls that were made to RunTask. +// Check the length with: +// +// len(mockedOperations.RunTaskCalls()) +func (mock *MiseOperationsMock) RunTaskCalls() []struct { + TaskName string +} { + var calls []struct { + TaskName string + } + mock.lockRunTask.RLock() + calls = mock.calls.RunTask + mock.lockRunTask.RUnlock() + return calls +} + +// Trust calls TrustFunc. +func (mock *MiseOperationsMock) Trust() error { + callInfo := struct { + }{} + mock.lockTrust.Lock() + mock.calls.Trust = append(mock.calls.Trust, callInfo) + mock.lockTrust.Unlock() + if mock.TrustFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.TrustFunc() +} + +// TrustCalls gets all the calls that were made to Trust. +// Check the length with: +// +// len(mockedOperations.TrustCalls()) +func (mock *MiseOperationsMock) TrustCalls() []struct { +} { + var calls []struct { + } + mock.lockTrust.RLock() + calls = mock.calls.Trust + mock.lockTrust.RUnlock() + return calls +} diff --git a/internal/mise/mise_test.go b/internal/mise/mise_test.go new file mode 100644 index 00000000..7b56f3f0 --- /dev/null +++ b/internal/mise/mise_test.go @@ -0,0 +1,160 @@ +package mise + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewOperations(t *testing.T) { + ops := NewOperations("/some/dir") + require.NotNil(t, ops, "NewOperations returned nil") + + // Verify it returns a cliOperations + cli, ok := ops.(*cliOperations) + require.True(t, ok, "NewOperations should return *cliOperations") + require.Equal(t, "/some/dir", cli.dir) +} + +func TestCLIOperationsImplementsInterface(t *testing.T) { + // Compile-time check that cliOperations implements Operations + var _ Operations = (*cliOperations)(nil) +} + +func TestFindConfigFile(t *testing.T) { + // Create a temp directory for testing + tempDir, err := os.MkdirTemp("", "mise-test-*") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + tests := []struct { + name string + setup func() // Create files before test + expected string + }{ + { + name: "no config file", + setup: func() {}, + expected: "", + }, + { + name: "mise.toml exists", + setup: func() { + os.WriteFile(filepath.Join(tempDir, "mise.toml"), []byte(""), 0644) + }, + expected: "mise.toml", + }, + { + name: ".mise.toml exists", + setup: func() { + os.WriteFile(filepath.Join(tempDir, ".mise.toml"), []byte(""), 0644) + }, + expected: ".mise.toml", + }, + { + name: ".tool-versions exists", + setup: func() { + os.WriteFile(filepath.Join(tempDir, ".tool-versions"), []byte(""), 0644) + }, + expected: ".tool-versions", + }, + { + name: ".mise/config.toml exists", + setup: func() { + os.MkdirAll(filepath.Join(tempDir, ".mise"), 0750) + os.WriteFile(filepath.Join(tempDir, ".mise", "config.toml"), []byte(""), 0644) + }, + expected: ".mise/config.toml", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Clean up any files from previous tests + os.RemoveAll(filepath.Join(tempDir, ".mise.toml")) + os.RemoveAll(filepath.Join(tempDir, "mise.toml")) + os.RemoveAll(filepath.Join(tempDir, ".mise")) + os.RemoveAll(filepath.Join(tempDir, ".tool-versions")) + + tt.setup() + + result := findConfigFile(tempDir) + require.Equal(t, tt.expected, result) + }) + } +} + +func TestFindConfigFile_Priority(t *testing.T) { + // Create a temp directory for testing + tempDir, err := os.MkdirTemp("", "mise-test-priority-*") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + // Create multiple config files + os.WriteFile(filepath.Join(tempDir, ".mise.toml"), []byte(""), 0644) + os.WriteFile(filepath.Join(tempDir, "mise.toml"), []byte(""), 0644) + os.WriteFile(filepath.Join(tempDir, ".tool-versions"), []byte(""), 0644) + + // Should return first in order: .mise.toml + result := findConfigFile(tempDir) + require.Equal(t, ".mise.toml", result, "expected '.mise.toml' (first in order)") +} + +func TestIsManaged_PackageLevel(t *testing.T) { + // Create a temp directory for testing + tempDir, err := os.MkdirTemp("", "mise-test-managed-*") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + // Test with no config file + require.False(t, IsManaged(tempDir), "expected IsManaged to return false when no config file exists") + + // Create a config file + os.WriteFile(filepath.Join(tempDir, ".mise.toml"), []byte(""), 0644) + + // Test with config file + require.True(t, IsManaged(tempDir), "expected IsManaged to return true when config file exists") +} + +func TestOperations_IsManaged(t *testing.T) { + // Create a temp directory for testing + tempDir, err := os.MkdirTemp("", "mise-test-ops-managed-*") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + ops := NewOperations(tempDir) + + // Test with no config file + require.False(t, ops.IsManaged(), "expected IsManaged to return false when no config file exists") + + // Create a config file + os.WriteFile(filepath.Join(tempDir, ".mise.toml"), []byte(""), 0644) + + // Test with config file + require.True(t, ops.IsManaged(), "expected IsManaged to return true when config file exists") +} + +func TestConfigFiles_AllVariants(t *testing.T) { + // Verify configFiles contains expected entries + expected := []string{ + ".mise.toml", + "mise.toml", + ".mise/config.toml", + ".tool-versions", + } + + require.Len(t, configFiles, len(expected)) + + for _, exp := range expected { + found := false + for _, cf := range configFiles { + if cf == exp { + found = true + break + } + } + require.True(t, found, "expected configFiles to contain %q", exp) + } +} diff --git a/internal/process/escape_test.go b/internal/process/escape_test.go new file mode 100644 index 00000000..92613afa --- /dev/null +++ b/internal/process/escape_test.go @@ -0,0 +1,58 @@ +package process + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestEscapePattern(t *testing.T) { + tests := []struct { + name string + pattern string + want string + }{ + { + name: "simple pattern", + pattern: "simple", + want: "'simple'", + }, + { + name: "pattern with single quote", + pattern: "test'pattern", + want: "'test'\\''pattern'", + }, + { + name: "pattern with multiple single quotes", + pattern: "test'pattern'here", + want: "'test'\\''pattern'\\''here'", + }, + { + name: "pattern with special characters", + pattern: "test$pattern*here", + want: "'test$pattern*here'", + }, + { + name: "empty pattern", + pattern: "", + want: "''", + }, + { + name: "pattern with spaces", + pattern: "test pattern", + want: "'test pattern'", + }, + { + name: "pattern with newline", + pattern: "test\npattern", + want: "'test\npattern'", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := escapePattern(tt.pattern) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/internal/process/process.go b/internal/process/process.go index b3677844..0076c14a 100644 --- a/internal/process/process.go +++ b/internal/process/process.go @@ -1,6 +1,8 @@ // Package process provides cross-platform process detection utilities. package process +//go:generate moq -stub -out process_mock.go . ProcessLister ProcessKiller + import ( "context" "fmt" diff --git a/internal/process/process_mock.go b/internal/process/process_mock.go new file mode 100644 index 00000000..12a4aa8c --- /dev/null +++ b/internal/process/process_mock.go @@ -0,0 +1,154 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package process + +import ( + "context" + "sync" +) + +// Ensure, that ProcessListerMock does implement ProcessLister. +// If this is not the case, regenerate this file with moq. +var _ ProcessLister = &ProcessListerMock{} + +// ProcessListerMock is a mock implementation of ProcessLister. +// +// func TestSomethingThatUsesProcessLister(t *testing.T) { +// +// // make and configure a mocked ProcessLister +// mockedProcessLister := &ProcessListerMock{ +// GetProcessListFunc: func(ctx context.Context) ([]string, error) { +// panic("mock out the GetProcessList method") +// }, +// } +// +// // use mockedProcessLister in code that requires ProcessLister +// // and then make assertions. +// +// } +type ProcessListerMock struct { + // GetProcessListFunc mocks the GetProcessList method. + GetProcessListFunc func(ctx context.Context) ([]string, error) + + // calls tracks calls to the methods. + calls struct { + // GetProcessList holds details about calls to the GetProcessList method. + GetProcessList []struct { + // Ctx is the ctx argument value. + Ctx context.Context + } + } + lockGetProcessList sync.RWMutex +} + +// GetProcessList calls GetProcessListFunc. +func (mock *ProcessListerMock) GetProcessList(ctx context.Context) ([]string, error) { + callInfo := struct { + Ctx context.Context + }{ + Ctx: ctx, + } + mock.lockGetProcessList.Lock() + mock.calls.GetProcessList = append(mock.calls.GetProcessList, callInfo) + mock.lockGetProcessList.Unlock() + if mock.GetProcessListFunc == nil { + var ( + stringsOut []string + errOut error + ) + return stringsOut, errOut + } + return mock.GetProcessListFunc(ctx) +} + +// GetProcessListCalls gets all the calls that were made to GetProcessList. +// Check the length with: +// +// len(mockedProcessLister.GetProcessListCalls()) +func (mock *ProcessListerMock) GetProcessListCalls() []struct { + Ctx context.Context +} { + var calls []struct { + Ctx context.Context + } + mock.lockGetProcessList.RLock() + calls = mock.calls.GetProcessList + mock.lockGetProcessList.RUnlock() + return calls +} + +// Ensure, that ProcessKillerMock does implement ProcessKiller. +// If this is not the case, regenerate this file with moq. +var _ ProcessKiller = &ProcessKillerMock{} + +// ProcessKillerMock is a mock implementation of ProcessKiller. +// +// func TestSomethingThatUsesProcessKiller(t *testing.T) { +// +// // make and configure a mocked ProcessKiller +// mockedProcessKiller := &ProcessKillerMock{ +// KillByPatternFunc: func(ctx context.Context, pattern string) error { +// panic("mock out the KillByPattern method") +// }, +// } +// +// // use mockedProcessKiller in code that requires ProcessKiller +// // and then make assertions. +// +// } +type ProcessKillerMock struct { + // KillByPatternFunc mocks the KillByPattern method. + KillByPatternFunc func(ctx context.Context, pattern string) error + + // calls tracks calls to the methods. + calls struct { + // KillByPattern holds details about calls to the KillByPattern method. + KillByPattern []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Pattern is the pattern argument value. + Pattern string + } + } + lockKillByPattern sync.RWMutex +} + +// KillByPattern calls KillByPatternFunc. +func (mock *ProcessKillerMock) KillByPattern(ctx context.Context, pattern string) error { + callInfo := struct { + Ctx context.Context + Pattern string + }{ + Ctx: ctx, + Pattern: pattern, + } + mock.lockKillByPattern.Lock() + mock.calls.KillByPattern = append(mock.calls.KillByPattern, callInfo) + mock.lockKillByPattern.Unlock() + if mock.KillByPatternFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.KillByPatternFunc(ctx, pattern) +} + +// KillByPatternCalls gets all the calls that were made to KillByPattern. +// Check the length with: +// +// len(mockedProcessKiller.KillByPatternCalls()) +func (mock *ProcessKillerMock) KillByPatternCalls() []struct { + Ctx context.Context + Pattern string +} { + var calls []struct { + Ctx context.Context + Pattern string + } + mock.lockKillByPattern.RLock() + calls = mock.calls.KillByPattern + mock.lockKillByPattern.RUnlock() + return calls +} diff --git a/internal/process/process_test.go b/internal/process/process_test.go index b3e9f8b4..a322ac12 100644 --- a/internal/process/process_test.go +++ b/internal/process/process_test.go @@ -1,38 +1,15 @@ -package process +package process_test import ( "context" "errors" "testing" + "github.com/newhook/co/internal/process" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -// mockProcessLister is a mock implementation of ProcessLister for testing. -type mockProcessLister struct { - processes []string - err error -} - -func (m *mockProcessLister) GetProcessList(ctx context.Context) ([]string, error) { - if m.err != nil { - return nil, m.err - } - return m.processes, nil -} - -// mockProcessKiller is a mock implementation of ProcessKiller for testing. -type mockProcessKiller struct { - killedPatterns []string - err error -} - -func (m *mockProcessKiller) KillByPattern(ctx context.Context, pattern string) error { - m.killedPatterns = append(m.killedPatterns, pattern) - return m.err -} - func TestIsProcessRunningWith(t *testing.T) { ctx := context.Background() @@ -82,8 +59,12 @@ func TestIsProcessRunningWith(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - lister := &mockProcessLister{processes: tt.processes} - got, err := IsProcessRunningWith(ctx, tt.pattern, lister) + lister := &process.ProcessListerMock{ + GetProcessListFunc: func(ctx context.Context) ([]string, error) { + return tt.processes, nil + }, + } + got, err := process.IsProcessRunningWith(ctx, tt.pattern, lister) if tt.wantErr { require.Error(t, err) @@ -98,9 +79,13 @@ func TestIsProcessRunningWith(t *testing.T) { func TestIsProcessRunningWith_ListerError(t *testing.T) { ctx := context.Background() - lister := &mockProcessLister{err: errors.New("ps command failed")} + lister := &process.ProcessListerMock{ + GetProcessListFunc: func(ctx context.Context) ([]string, error) { + return nil, errors.New("ps command failed") + }, + } - _, err := IsProcessRunningWith(ctx, "myapp", lister) + _, err := process.IsProcessRunningWith(ctx, "myapp", lister) require.Error(t, err) assert.Contains(t, err.Error(), "failed to get process list") } @@ -149,10 +134,20 @@ func TestKillProcessWith(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - lister := &mockProcessLister{processes: tt.processes} - killer := &mockProcessKiller{err: tt.killerErr} + lister := &process.ProcessListerMock{ + GetProcessListFunc: func(ctx context.Context) ([]string, error) { + return tt.processes, nil + }, + } + var killedPatterns []string + killer := &process.ProcessKillerMock{ + KillByPatternFunc: func(ctx context.Context, pattern string) error { + killedPatterns = append(killedPatterns, pattern) + return tt.killerErr + }, + } - err := KillProcessWith(ctx, tt.pattern, lister, killer) + err := process.KillProcessWith(ctx, tt.pattern, lister, killer) if tt.wantErr { require.Error(t, err) @@ -161,9 +156,9 @@ func TestKillProcessWith(t *testing.T) { } if tt.wantKillCalled { - assert.Contains(t, killer.killedPatterns, tt.pattern) + assert.Contains(t, killedPatterns, tt.pattern) } else { - assert.Empty(t, killer.killedPatterns) + assert.Empty(t, killedPatterns) } }) } @@ -171,62 +166,15 @@ func TestKillProcessWith(t *testing.T) { func TestKillProcessWith_ListerError(t *testing.T) { ctx := context.Background() - lister := &mockProcessLister{err: errors.New("ps command failed")} - killer := &mockProcessKiller{} - - err := KillProcessWith(ctx, "myapp", lister, killer) - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to get process list") - assert.Empty(t, killer.killedPatterns) -} - -func TestEscapePattern(t *testing.T) { - tests := []struct { - name string - pattern string - want string - }{ - { - name: "simple pattern", - pattern: "simple", - want: "'simple'", - }, - { - name: "pattern with single quote", - pattern: "test'pattern", - want: "'test'\\''pattern'", - }, - { - name: "pattern with multiple single quotes", - pattern: "test'pattern'here", - want: "'test'\\''pattern'\\''here'", - }, - { - name: "pattern with special characters", - pattern: "test$pattern*here", - want: "'test$pattern*here'", - }, - { - name: "empty pattern", - pattern: "", - want: "''", - }, - { - name: "pattern with spaces", - pattern: "test pattern", - want: "'test pattern'", - }, - { - name: "pattern with newline", - pattern: "test\npattern", - want: "'test\npattern'", + lister := &process.ProcessListerMock{ + GetProcessListFunc: func(ctx context.Context) ([]string, error) { + return nil, errors.New("ps command failed") }, } + killer := &process.ProcessKillerMock{} - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := escapePattern(tt.pattern) - assert.Equal(t, tt.want, got) - }) - } + err := process.KillProcessWith(ctx, "myapp", lister, killer) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to get process list") + assert.Empty(t, killer.KillByPatternCalls()) } diff --git a/internal/project/config_test.go b/internal/project/config_test.go index b838129d..3dbbaf23 100644 --- a/internal/project/config_test.go +++ b/internal/project/config_test.go @@ -3,7 +3,9 @@ package project import ( "testing" "time" + "github.com/BurntSushi/toml" + "github.com/stretchr/testify/require" ) func TestGeneratedConfigIsValidTOML(t *testing.T) { @@ -22,20 +24,15 @@ func TestGeneratedConfigIsValidTOML(t *testing.T) { // Try to parse the generated content as valid TOML var parsed map[string]interface{} - if _, err := toml.Decode(content, &parsed); err != nil { - t.Errorf("Generated config is not valid TOML: %v\n\nContent:\n%s", err, content) - } + _, err := toml.Decode(content, &parsed) + require.NoError(t, err, "Generated config is not valid TOML:\n%s", content) // Verify key fields are present project := parsed["project"].(map[string]interface{}) - if project["name"] != "test-project" { - t.Errorf("Expected project.name to be 'test-project', got %v", project["name"]) - } + require.Equal(t, "test-project", project["name"]) repo := parsed["repo"].(map[string]interface{}) - if repo["type"] != "github" { - t.Errorf("Expected repo.type to be 'github', got %v", repo["type"]) - } + require.Equal(t, "github", repo["type"]) } func TestGeneratedConfigRoundTrip(t *testing.T) { @@ -55,26 +52,15 @@ func TestGeneratedConfigRoundTrip(t *testing.T) { // Parse back into a Config struct var loaded Config - if _, err := toml.Decode(content, &loaded); err != nil { - t.Fatalf("Failed to decode generated config: %v", err) - } + _, err := toml.Decode(content, &loaded) + require.NoError(t, err) // Verify fields match - if loaded.Project.Name != original.Project.Name { - t.Errorf("Project.Name: expected %q, got %q", original.Project.Name, loaded.Project.Name) - } - if !loaded.Project.CreatedAt.Equal(original.Project.CreatedAt) { - t.Errorf("Project.CreatedAt: expected %v, got %v", original.Project.CreatedAt, loaded.Project.CreatedAt) - } - if loaded.Repo.Type != original.Repo.Type { - t.Errorf("Repo.Type: expected %q, got %q", original.Repo.Type, loaded.Repo.Type) - } - if loaded.Repo.Source != original.Repo.Source { - t.Errorf("Repo.Source: expected %q, got %q", original.Repo.Source, loaded.Repo.Source) - } - if loaded.Repo.Path != original.Repo.Path { - t.Errorf("Repo.Path: expected %q, got %q", original.Repo.Path, loaded.Repo.Path) - } + require.Equal(t, original.Project.Name, loaded.Project.Name) + require.True(t, loaded.Project.CreatedAt.Equal(original.Project.CreatedAt)) + require.Equal(t, original.Repo.Type, loaded.Repo.Type) + require.Equal(t, original.Repo.Source, loaded.Repo.Source) + require.Equal(t, original.Repo.Path, loaded.Repo.Path) } func TestGeneratedConfigWithSpecialCharacters(t *testing.T) { @@ -95,17 +81,12 @@ func TestGeneratedConfigWithSpecialCharacters(t *testing.T) { // This should parse successfully even with special characters var parsed Config - if _, err := toml.Decode(content, &parsed); err != nil { - t.Errorf("Failed to parse config with special characters: %v\n\nContent:\n%s", err, content) - } + _, err := toml.Decode(content, &parsed) + require.NoError(t, err, "Failed to parse config with special characters:\n%s", content) // Verify values - if parsed.Project.Name != cfg.Project.Name { - t.Errorf("Project.Name: expected %q, got %q", cfg.Project.Name, parsed.Project.Name) - } - if parsed.Repo.Source != cfg.Repo.Source { - t.Errorf("Repo.Source: expected %q, got %q", cfg.Repo.Source, parsed.Repo.Source) - } + require.Equal(t, cfg.Project.Name, parsed.Project.Name) + require.Equal(t, cfg.Repo.Source, parsed.Repo.Source) } func TestShouldSkipPermissionsDefault(t *testing.T) { @@ -115,13 +96,10 @@ func TestShouldSkipPermissionsDefault(t *testing.T) { name = "test" ` var cfg Config - if _, err := toml.Decode(tomlContent, &cfg); err != nil { - t.Fatalf("Failed to decode: %v", err) - } + _, err := toml.Decode(tomlContent, &cfg) + require.NoError(t, err) - if !cfg.Claude.ShouldSkipPermissions() { - t.Error("Expected ShouldSkipPermissions() to return true by default, got false") - } + require.True(t, cfg.Claude.ShouldSkipPermissions(), "Expected ShouldSkipPermissions() to return true by default") } func TestShouldSkipPermissionsExplicitFalse(t *testing.T) { @@ -134,13 +112,10 @@ func TestShouldSkipPermissionsExplicitFalse(t *testing.T) { skip_permissions = false ` var cfg Config - if _, err := toml.Decode(tomlContent, &cfg); err != nil { - t.Fatalf("Failed to decode: %v", err) - } + _, err := toml.Decode(tomlContent, &cfg) + require.NoError(t, err) - if cfg.Claude.ShouldSkipPermissions() { - t.Error("Expected ShouldSkipPermissions() to return false when explicitly set, got true") - } + require.False(t, cfg.Claude.ShouldSkipPermissions(), "Expected ShouldSkipPermissions() to return false when explicitly set") } func TestGeneratedConfigWithUTF8(t *testing.T) { @@ -161,14 +136,11 @@ func TestGeneratedConfigWithUTF8(t *testing.T) { // This should parse successfully var parsed Config - if _, err := toml.Decode(content, &parsed); err != nil { - t.Errorf("Failed to parse config with UTF-8: %v", err) - } + _, err := toml.Decode(content, &parsed) + require.NoError(t, err, "Failed to parse config with UTF-8") // Verify values - if parsed.Project.Name != cfg.Project.Name { - t.Errorf("Project.Name: expected %q, got %q", cfg.Project.Name, parsed.Project.Name) - } + require.Equal(t, cfg.Project.Name, parsed.Project.Name) } func TestLogParserConfig_ShouldUseClaude(t *testing.T) { @@ -202,9 +174,7 @@ func TestLogParserConfig_ShouldUseClaude(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := tt.config.ShouldUseClaude() - if result != tt.expected { - t.Errorf("ShouldUseClaude() = %v, want %v", result, tt.expected) - } + require.Equal(t, tt.expected, result) }) } } @@ -240,19 +210,17 @@ func TestLogParserConfig_GetModel(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := tt.config.GetModel() - if result != tt.expected { - t.Errorf("GetModel() = %q, want %q", result, tt.expected) - } + require.Equal(t, tt.expected, result) }) } } func TestLogParserConfigFromTOML(t *testing.T) { tests := []struct { - name string - tomlContent string - wantUseClaude bool - wantModel string + name string + tomlContent string + wantUseClaude bool + wantModel string }{ { name: "Not specified defaults", @@ -306,17 +274,11 @@ use_claude = true for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var cfg Config - if _, err := toml.Decode(tt.tomlContent, &cfg); err != nil { - t.Fatalf("Failed to decode TOML: %v", err) - } - - if cfg.LogParser.ShouldUseClaude() != tt.wantUseClaude { - t.Errorf("ShouldUseClaude() = %v, want %v", cfg.LogParser.ShouldUseClaude(), tt.wantUseClaude) - } + _, err := toml.Decode(tt.tomlContent, &cfg) + require.NoError(t, err) - if cfg.LogParser.GetModel() != tt.wantModel { - t.Errorf("GetModel() = %q, want %q", cfg.LogParser.GetModel(), tt.wantModel) - } + require.Equal(t, tt.wantUseClaude, cfg.LogParser.ShouldUseClaude()) + require.Equal(t, tt.wantModel, cfg.LogParser.GetModel()) }) } } diff --git a/internal/task/internal_test.go b/internal/task/internal_test.go new file mode 100644 index 00000000..96472186 --- /dev/null +++ b/internal/task/internal_test.go @@ -0,0 +1,152 @@ +package task + +import ( + "testing" + + "github.com/newhook/co/internal/beads" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCanAddToTask(t *testing.T) { + graph := &DependencyGraph{ + DependsOn: map[string][]string{ + "b": {"a"}, + }, + Dependents: map[string][]string{ + "a": {"b"}, + }, + } + + assigned := map[string]int{ + "a": 0, // a is in task 0 + } + + // b depends on a which is in task 0, so b can be added to task 0 or later + assert.True(t, canAddToTask("b", 0, assigned, graph), "b should be addable to task 0 (same as dependency)") + assert.True(t, canAddToTask("b", 1, assigned, graph), "b should be addable to task 1 (after dependency)") + + // Can't add b to task before a + assigned["a"] = 1 + assert.False(t, canAddToTask("b", 0, assigned, graph), "b should not be addable to task 0 (before dependency in task 1)") +} + +// ComputeInterTaskDeps is a helper function that mimics the inter-task dependency +// computation logic from handlePostEstimation. Used for testing. +func ComputeInterTaskDeps(tasks []Task, dependencies map[string][]beads.Dependency) map[int][]int { + // Build beadID → task index mapping + beadToTask := make(map[string]int) + for i, t := range tasks { + for _, beadID := range t.BeadIDs { + beadToTask[beadID] = i + } + } + + // Compute inter-task dependencies + // Returns map of taskIdx → list of task indices it depends on + interTaskDeps := make(map[int]map[int]bool) + for beadID, deps := range dependencies { + taskIdx, ok := beadToTask[beadID] + if !ok { + continue + } + for _, dep := range deps { + depTaskIdx, ok := beadToTask[dep.DependsOnID] + if !ok { + continue + } + if taskIdx == depTaskIdx { + continue // same task, no inter-task dependency + } + if interTaskDeps[taskIdx] == nil { + interTaskDeps[taskIdx] = make(map[int]bool) + } + interTaskDeps[taskIdx][depTaskIdx] = true + } + } + + // Convert to slice representation + result := make(map[int][]int) + for taskIdx, depSet := range interTaskDeps { + for depIdx := range depSet { + result[taskIdx] = append(result[taskIdx], depIdx) + } + } + return result +} + +func TestComputeInterTaskDepsChain(t *testing.T) { + // Chain: c depends on b, b depends on a - each in separate task + tasks := []Task{ + {ID: "task-1", BeadIDs: []string{"a"}}, + {ID: "task-2", BeadIDs: []string{"b"}}, + {ID: "task-3", BeadIDs: []string{"c"}}, + } + + dependencies := map[string][]beads.Dependency{ + "b": {{IssueID: "b", DependsOnID: "a", Type: "blocks"}}, + "c": {{IssueID: "c", DependsOnID: "b", Type: "blocks"}}, + } + + interDeps := ComputeInterTaskDeps(tasks, dependencies) + + // task-2 (index 1) should depend on task-1 (index 0) + require.Contains(t, interDeps, 1, "task-2 should have dependencies") + assert.Contains(t, interDeps[1], 0, "task-2 should depend on task-1") + + // task-3 (index 2) should depend on task-2 (index 1) + require.Contains(t, interDeps, 2, "task-3 should have dependencies") + assert.Contains(t, interDeps[2], 1, "task-3 should depend on task-2") + + // task-1 (index 0) should have no dependencies + assert.NotContains(t, interDeps, 0, "task-1 should have no inter-task dependencies") +} + +func TestComputeInterTaskDepsDiamond(t *testing.T) { + // Diamond: a and b independent, c depends on both, d depends on c + tasks := []Task{ + {ID: "task-1", BeadIDs: []string{"a"}}, + {ID: "task-2", BeadIDs: []string{"b"}}, + {ID: "task-3", BeadIDs: []string{"c"}}, + {ID: "task-4", BeadIDs: []string{"d"}}, + } + + dependencies := map[string][]beads.Dependency{ + "c": { + {IssueID: "c", DependsOnID: "a", Type: "blocks"}, + {IssueID: "c", DependsOnID: "b", Type: "blocks"}, + }, + "d": {{IssueID: "d", DependsOnID: "c", Type: "blocks"}}, + } + + interDeps := ComputeInterTaskDeps(tasks, dependencies) + + // task-3 (index 2) should depend on both task-1 (index 0) and task-2 (index 1) + require.Contains(t, interDeps, 2, "task-3 should have dependencies") + assert.Contains(t, interDeps[2], 0, "task-3 should depend on task-1") + assert.Contains(t, interDeps[2], 1, "task-3 should depend on task-2") + + // task-4 (index 3) should depend on task-3 (index 2) + require.Contains(t, interDeps, 3, "task-4 should have dependencies") + assert.Contains(t, interDeps[3], 2, "task-4 should depend on task-3") + + // task-1 and task-2 should have no dependencies + assert.NotContains(t, interDeps, 0, "task-1 should have no inter-task dependencies") + assert.NotContains(t, interDeps, 1, "task-2 should have no inter-task dependencies") +} + +func TestComputeInterTaskDepsSameTaskNoDeps(t *testing.T) { + // Both beads in same task, b depends on a + tasks := []Task{ + {ID: "task-1", BeadIDs: []string{"a", "b"}}, + } + + dependencies := map[string][]beads.Dependency{ + "b": {{IssueID: "b", DependsOnID: "a", Type: "blocks"}}, + } + + interDeps := ComputeInterTaskDeps(tasks, dependencies) + + // No inter-task dependencies since both beads are in the same task + assert.Empty(t, interDeps, "same-task dependencies should not create inter-task deps") +} diff --git a/internal/task/planner_test.go b/internal/task/planner_test.go index 3dbad565..c63a307f 100644 --- a/internal/task/planner_test.go +++ b/internal/task/planner_test.go @@ -1,27 +1,15 @@ -package task +package task_test import ( "context" "testing" "github.com/newhook/co/internal/beads" + "github.com/newhook/co/internal/task" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -// mockEstimator returns fixed complexity scores for testing. -type mockEstimator struct { - scores map[string]int -} - -func (m *mockEstimator) Estimate(ctx context.Context, bead beads.Bead) (int, int, error) { - score := m.scores[bead.ID] - if score == 0 { - score = 5 // default - } - return score, score * 1000, nil -} - func TestBuildDependencyGraph(t *testing.T) { beadList := []beads.Bead{ {ID: "a", Title: "A"}, @@ -34,7 +22,7 @@ func TestBuildDependencyGraph(t *testing.T) { "c": {{IssueID: "c", DependsOnID: "b", Type: "blocks"}}, } - graph := BuildDependencyGraph(beadList, dependencies) + graph := task.BuildDependencyGraph(beadList, dependencies) // b depends on a require.Len(t, graph.DependsOn["b"], 1, "expected b to have 1 dependency") @@ -59,7 +47,7 @@ func TestBuildDependencyGraphIgnoresExternalDeps(t *testing.T) { "a": {{IssueID: "a", DependsOnID: "external", Type: "blocks"}}, } - graph := BuildDependencyGraph(beadList, dependencies) + graph := task.BuildDependencyGraph(beadList, dependencies) // External dependency should be filtered out since "external" is not in the beads list assert.Empty(t, graph.DependsOn["a"], "external dependency should be ignored") @@ -78,8 +66,8 @@ func TestTopologicalSort(t *testing.T) { "b": {{IssueID: "b", DependsOnID: "a", Type: "blocks"}}, } - graph := BuildDependencyGraph(beadList, dependencies) - sorted, err := TopologicalSort(graph, beadList) + graph := task.BuildDependencyGraph(beadList, dependencies) + sorted, err := task.TopologicalSort(graph, beadList) require.NoError(t, err, "TopologicalSort failed") // a should come before b, b should come before c @@ -103,17 +91,24 @@ func TestTopologicalSortDetectsCycle(t *testing.T) { "b": {{IssueID: "b", DependsOnID: "a", Type: "blocks"}}, } - graph := BuildDependencyGraph(beadList, dependencies) - _, err := TopologicalSort(graph, beadList) + graph := task.BuildDependencyGraph(beadList, dependencies) + _, err := task.TopologicalSort(graph, beadList) assert.Error(t, err, "expected error for cycle detection") } func TestPlanSimple(t *testing.T) { ctx := context.Background() - estimator := &mockEstimator{ - scores: map[string]int{"a": 3, "b": 3, "c": 3}, + scores := map[string]int{"a": 3, "b": 3, "c": 3} + estimator := &task.ComplexityEstimatorMock{ + EstimateFunc: func(ctx context.Context, bead beads.Bead) (int, int, error) { + score := scores[bead.ID] + if score == 0 { + score = 5 // default + } + return score, score * 1000, nil + }, } - planner := NewDefaultPlanner(estimator) + planner := task.NewDefaultPlanner(estimator) beadList := []beads.Bead{ {ID: "a", Title: "A"}, @@ -133,10 +128,17 @@ func TestPlanSimple(t *testing.T) { func TestPlanSplitByBudget(t *testing.T) { ctx := context.Background() - estimator := &mockEstimator{ - scores: map[string]int{"a": 5, "b": 5, "c": 5}, + scores := map[string]int{"a": 5, "b": 5, "c": 5} + estimator := &task.ComplexityEstimatorMock{ + EstimateFunc: func(ctx context.Context, bead beads.Bead) (int, int, error) { + score := scores[bead.ID] + if score == 0 { + score = 5 // default + } + return score, score * 1000, nil + }, } - planner := NewDefaultPlanner(estimator) + planner := task.NewDefaultPlanner(estimator) beadList := []beads.Bead{ {ID: "a", Title: "A"}, @@ -154,18 +156,25 @@ func TestPlanSplitByBudget(t *testing.T) { // Verify all beads are assigned totalBeads := 0 - for _, task := range tasks { - totalBeads += len(task.BeadIDs) + for _, t := range tasks { + totalBeads += len(t.BeadIDs) } assert.Equal(t, 3, totalBeads, "expected 3 total beads") } func TestPlanRespectsDependencies(t *testing.T) { ctx := context.Background() - estimator := &mockEstimator{ - scores: map[string]int{"a": 3, "b": 3}, + scores := map[string]int{"a": 3, "b": 3} + estimator := &task.ComplexityEstimatorMock{ + EstimateFunc: func(ctx context.Context, bead beads.Bead) (int, int, error) { + score := scores[bead.ID] + if score == 0 { + score = 5 // default + } + return score, score * 1000, nil + }, } - planner := NewDefaultPlanner(estimator) + planner := task.NewDefaultPlanner(estimator) beadList := []beads.Bead{ {ID: "a", Title: "A"}, @@ -182,8 +191,8 @@ func TestPlanRespectsDependencies(t *testing.T) { // Find which tasks contain a and b taskForBead := make(map[string]int) - for i, task := range tasks { - for _, id := range task.BeadIDs { + for i, t := range tasks { + for _, id := range t.BeadIDs { taskForBead[id] = i } } @@ -194,8 +203,8 @@ func TestPlanRespectsDependencies(t *testing.T) { func TestPlanEmpty(t *testing.T) { ctx := context.Background() - estimator := &mockEstimator{} - planner := NewDefaultPlanner(estimator) + estimator := &task.ComplexityEstimatorMock{} + planner := task.NewDefaultPlanner(estimator) dependencies := map[string][]beads.Dependency{} @@ -208,10 +217,17 @@ func TestPlanEmpty(t *testing.T) { func TestPlanFirstFitDecreasing(t *testing.T) { ctx := context.Background() // Larger beads are assigned first (by token estimate) - estimator := &mockEstimator{ - scores: map[string]int{"small": 2, "medium": 4, "large": 6}, + scores := map[string]int{"small": 2, "medium": 4, "large": 6} + estimator := &task.ComplexityEstimatorMock{ + EstimateFunc: func(ctx context.Context, bead beads.Bead) (int, int, error) { + score := scores[bead.ID] + if score == 0 { + score = 5 // default + } + return score, score * 1000, nil + }, } - planner := NewDefaultPlanner(estimator) + planner := task.NewDefaultPlanner(estimator) beadList := []beads.Bead{ {ID: "small", Title: "Small"}, @@ -229,38 +245,20 @@ func TestPlanFirstFitDecreasing(t *testing.T) { assert.Len(t, tasks, 2, "expected 2 tasks") } -func TestCanAddToTask(t *testing.T) { - graph := &DependencyGraph{ - DependsOn: map[string][]string{ - "b": {"a"}, - }, - Dependents: map[string][]string{ - "a": {"b"}, - }, - } - - assigned := map[string]int{ - "a": 0, // a is in task 0 - } - - // b depends on a which is in task 0, so b can be added to task 0 or later - assert.True(t, canAddToTask("b", 0, assigned, graph), "b should be addable to task 0 (same as dependency)") - assert.True(t, canAddToTask("b", 1, assigned, graph), "b should be addable to task 1 (after dependency)") - - // Can't add b to task before a - assigned["a"] = 1 - assert.False(t, canAddToTask("b", 0, assigned, graph), "b should not be addable to task 0 (before dependency in task 1)") -} - -// TestPlanChainDependencySplitAcrossTasks tests that beads with chain dependency -// (A→B→C) split across tasks create a task chain with correct ordering. func TestPlanChainDependencySplitAcrossTasks(t *testing.T) { ctx := context.Background() // Small budget to force each bead into separate task - estimator := &mockEstimator{ - scores: map[string]int{"a": 5, "b": 5, "c": 5}, + scores := map[string]int{"a": 5, "b": 5, "c": 5} + estimator := &task.ComplexityEstimatorMock{ + EstimateFunc: func(ctx context.Context, bead beads.Bead) (int, int, error) { + score := scores[bead.ID] + if score == 0 { + score = 5 // default + } + return score, score * 1000, nil + }, } - planner := NewDefaultPlanner(estimator) + planner := task.NewDefaultPlanner(estimator) beadList := []beads.Bead{ {ID: "a", Title: "A"}, @@ -281,8 +279,8 @@ func TestPlanChainDependencySplitAcrossTasks(t *testing.T) { // Find which task contains each bead taskForBead := make(map[string]int) - for i, task := range tasks { - for _, id := range task.BeadIDs { + for i, t := range tasks { + for _, id := range t.BeadIDs { taskForBead[id] = i } } @@ -292,14 +290,19 @@ func TestPlanChainDependencySplitAcrossTasks(t *testing.T) { assert.Less(t, taskForBead["b"], taskForBead["c"], "b should be in earlier task than c") } -// TestPlanDiamondDependencySplitAcrossTasks tests that beads with diamond dependency -// create the correct task graph. Diamond: A, B depend on nothing; C depends on both A and B. func TestPlanDiamondDependencySplitAcrossTasks(t *testing.T) { ctx := context.Background() - estimator := &mockEstimator{ - scores: map[string]int{"a": 5, "b": 5, "c": 5, "d": 5}, + scores := map[string]int{"a": 5, "b": 5, "c": 5, "d": 5} + estimator := &task.ComplexityEstimatorMock{ + EstimateFunc: func(ctx context.Context, bead beads.Bead) (int, int, error) { + score := scores[bead.ID] + if score == 0 { + score = 5 // default + } + return score, score * 1000, nil + }, } - planner := NewDefaultPlanner(estimator) + planner := task.NewDefaultPlanner(estimator) // Diamond: a and b are independent, c depends on both, d depends on c beadList := []beads.Bead{ @@ -324,8 +327,8 @@ func TestPlanDiamondDependencySplitAcrossTasks(t *testing.T) { // Find which task contains each bead taskForBead := make(map[string]int) - for i, task := range tasks { - for _, id := range task.BeadIDs { + for i, t := range tasks { + for _, id := range t.BeadIDs { taskForBead[id] = i } } @@ -337,14 +340,19 @@ func TestPlanDiamondDependencySplitAcrossTasks(t *testing.T) { assert.Less(t, taskForBead["c"], taskForBead["d"], "c should be in earlier task than d") } -// TestPlanSameTaskDependenciesNoSelfDep tests that beads in the same task -// with dependencies do not create self-dependency (task depending on itself). func TestPlanSameTaskDependenciesNoSelfDep(t *testing.T) { ctx := context.Background() - estimator := &mockEstimator{ - scores: map[string]int{"a": 2, "b": 2}, + scores := map[string]int{"a": 2, "b": 2} + estimator := &task.ComplexityEstimatorMock{ + EstimateFunc: func(ctx context.Context, bead beads.Bead) (int, int, error) { + score := scores[bead.ID] + if score == 0 { + score = 5 // default + } + return score, score * 1000, nil + }, } - planner := NewDefaultPlanner(estimator) + planner := task.NewDefaultPlanner(estimator) beadList := []beads.Bead{ {ID: "a", Title: "A"}, @@ -363,135 +371,11 @@ func TestPlanSameTaskDependenciesNoSelfDep(t *testing.T) { // Both beads should be in the same task taskForBead := make(map[string]int) - for i, task := range tasks { - for _, id := range task.BeadIDs { + for i, t := range tasks { + for _, id := range t.BeadIDs { taskForBead[id] = i } } assert.Equal(t, taskForBead["a"], taskForBead["b"], "a and b should be in the same task") } - -// ComputeInterTaskDeps is a helper function that mimics the inter-task dependency -// computation logic from handlePostEstimation. Used for testing. -func ComputeInterTaskDeps(tasks []Task, dependencies map[string][]beads.Dependency) map[int][]int { - // Build beadID → task index mapping - beadToTask := make(map[string]int) - for i, t := range tasks { - for _, beadID := range t.BeadIDs { - beadToTask[beadID] = i - } - } - - // Compute inter-task dependencies - // Returns map of taskIdx → list of task indices it depends on - interTaskDeps := make(map[int]map[int]bool) - for beadID, deps := range dependencies { - taskIdx, ok := beadToTask[beadID] - if !ok { - continue - } - for _, dep := range deps { - depTaskIdx, ok := beadToTask[dep.DependsOnID] - if !ok { - continue - } - if taskIdx == depTaskIdx { - continue // same task, no inter-task dependency - } - if interTaskDeps[taskIdx] == nil { - interTaskDeps[taskIdx] = make(map[int]bool) - } - interTaskDeps[taskIdx][depTaskIdx] = true - } - } - - // Convert to slice representation - result := make(map[int][]int) - for taskIdx, depSet := range interTaskDeps { - for depIdx := range depSet { - result[taskIdx] = append(result[taskIdx], depIdx) - } - } - return result -} - -// TestComputeInterTaskDepsChain tests inter-task dependency computation for chain. -func TestComputeInterTaskDepsChain(t *testing.T) { - // Chain: c depends on b, b depends on a - each in separate task - tasks := []Task{ - {ID: "task-1", BeadIDs: []string{"a"}}, - {ID: "task-2", BeadIDs: []string{"b"}}, - {ID: "task-3", BeadIDs: []string{"c"}}, - } - - dependencies := map[string][]beads.Dependency{ - "b": {{IssueID: "b", DependsOnID: "a", Type: "blocks"}}, - "c": {{IssueID: "c", DependsOnID: "b", Type: "blocks"}}, - } - - interDeps := ComputeInterTaskDeps(tasks, dependencies) - - // task-2 (index 1) should depend on task-1 (index 0) - require.Contains(t, interDeps, 1, "task-2 should have dependencies") - assert.Contains(t, interDeps[1], 0, "task-2 should depend on task-1") - - // task-3 (index 2) should depend on task-2 (index 1) - require.Contains(t, interDeps, 2, "task-3 should have dependencies") - assert.Contains(t, interDeps[2], 1, "task-3 should depend on task-2") - - // task-1 (index 0) should have no dependencies - assert.NotContains(t, interDeps, 0, "task-1 should have no inter-task dependencies") -} - -// TestComputeInterTaskDepsDiamond tests inter-task dependency computation for diamond. -func TestComputeInterTaskDepsDiamond(t *testing.T) { - // Diamond: a and b independent, c depends on both, d depends on c - tasks := []Task{ - {ID: "task-1", BeadIDs: []string{"a"}}, - {ID: "task-2", BeadIDs: []string{"b"}}, - {ID: "task-3", BeadIDs: []string{"c"}}, - {ID: "task-4", BeadIDs: []string{"d"}}, - } - - dependencies := map[string][]beads.Dependency{ - "c": { - {IssueID: "c", DependsOnID: "a", Type: "blocks"}, - {IssueID: "c", DependsOnID: "b", Type: "blocks"}, - }, - "d": {{IssueID: "d", DependsOnID: "c", Type: "blocks"}}, - } - - interDeps := ComputeInterTaskDeps(tasks, dependencies) - - // task-3 (index 2) should depend on both task-1 (index 0) and task-2 (index 1) - require.Contains(t, interDeps, 2, "task-3 should have dependencies") - assert.Contains(t, interDeps[2], 0, "task-3 should depend on task-1") - assert.Contains(t, interDeps[2], 1, "task-3 should depend on task-2") - - // task-4 (index 3) should depend on task-3 (index 2) - require.Contains(t, interDeps, 3, "task-4 should have dependencies") - assert.Contains(t, interDeps[3], 2, "task-4 should depend on task-3") - - // task-1 and task-2 should have no dependencies - assert.NotContains(t, interDeps, 0, "task-1 should have no inter-task dependencies") - assert.NotContains(t, interDeps, 1, "task-2 should have no inter-task dependencies") -} - -// TestComputeInterTaskDepsSameTaskNoDeps tests that beads in the same task -// do not create self-dependencies. -func TestComputeInterTaskDepsSameTaskNoDeps(t *testing.T) { - // Both beads in same task, b depends on a - tasks := []Task{ - {ID: "task-1", BeadIDs: []string{"a", "b"}}, - } - - dependencies := map[string][]beads.Dependency{ - "b": {{IssueID: "b", DependsOnID: "a", Type: "blocks"}}, - } - - interDeps := ComputeInterTaskDeps(tasks, dependencies) - - // No inter-task dependencies since both beads are in the same task - assert.Empty(t, interDeps, "same-task dependencies should not create inter-task deps") -} diff --git a/internal/task/task.go b/internal/task/task.go index 71a936a8..1aacfb08 100644 --- a/internal/task/task.go +++ b/internal/task/task.go @@ -1,5 +1,7 @@ package task +//go:generate moq -stub -out task_mock.go . ComplexityEstimator + import ( "context" diff --git a/internal/task/task_mock.go b/internal/task/task_mock.go new file mode 100644 index 00000000..b6b0ffdf --- /dev/null +++ b/internal/task/task_mock.go @@ -0,0 +1,87 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package task + +import ( + "context" + "github.com/newhook/co/internal/beads" + "sync" +) + +// Ensure, that ComplexityEstimatorMock does implement ComplexityEstimator. +// If this is not the case, regenerate this file with moq. +var _ ComplexityEstimator = &ComplexityEstimatorMock{} + +// ComplexityEstimatorMock is a mock implementation of ComplexityEstimator. +// +// func TestSomethingThatUsesComplexityEstimator(t *testing.T) { +// +// // make and configure a mocked ComplexityEstimator +// mockedComplexityEstimator := &ComplexityEstimatorMock{ +// EstimateFunc: func(ctx context.Context, bead beads.Bead) (int, int, error) { +// panic("mock out the Estimate method") +// }, +// } +// +// // use mockedComplexityEstimator in code that requires ComplexityEstimator +// // and then make assertions. +// +// } +type ComplexityEstimatorMock struct { + // EstimateFunc mocks the Estimate method. + EstimateFunc func(ctx context.Context, bead beads.Bead) (int, int, error) + + // calls tracks calls to the methods. + calls struct { + // Estimate holds details about calls to the Estimate method. + Estimate []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Bead is the bead argument value. + Bead beads.Bead + } + } + lockEstimate sync.RWMutex +} + +// Estimate calls EstimateFunc. +func (mock *ComplexityEstimatorMock) Estimate(ctx context.Context, bead beads.Bead) (int, int, error) { + callInfo := struct { + Ctx context.Context + Bead beads.Bead + }{ + Ctx: ctx, + Bead: bead, + } + mock.lockEstimate.Lock() + mock.calls.Estimate = append(mock.calls.Estimate, callInfo) + mock.lockEstimate.Unlock() + if mock.EstimateFunc == nil { + var ( + scoreOut int + tokensOut int + errOut error + ) + return scoreOut, tokensOut, errOut + } + return mock.EstimateFunc(ctx, bead) +} + +// EstimateCalls gets all the calls that were made to Estimate. +// Check the length with: +// +// len(mockedComplexityEstimator.EstimateCalls()) +func (mock *ComplexityEstimatorMock) EstimateCalls() []struct { + Ctx context.Context + Bead beads.Bead +} { + var calls []struct { + Ctx context.Context + Bead beads.Bead + } + mock.lockEstimate.RLock() + calls = mock.calls.Estimate + mock.lockEstimate.RUnlock() + return calls +} diff --git a/internal/tui/tui_plan_dialogs_test.go b/internal/tui/tui_plan_dialogs_test.go index b7c56548..e9144d73 100644 --- a/internal/tui/tui_plan_dialogs_test.go +++ b/internal/tui/tui_plan_dialogs_test.go @@ -6,6 +6,7 @@ import ( "testing" tea "github.com/charmbracelet/bubbletea" + "github.com/stretchr/testify/require" ) // TestMultiSelectionCloseConfirmation tests the close confirmation dialog with multiple selected beads @@ -107,13 +108,13 @@ func TestMultiSelectionCloseConfirmation(t *testing.T) { // Check if the dialog shows the correct number of beads if tt.expectedCount == 1 { // For single bead, check title shows "Close Issue" - if !strings.Contains(dialogContent, "Close Issue") { - t.Errorf("%s: Expected 'Close Issue' in dialog for single bead", tt.description) - } + require.True(t, strings.Contains(dialogContent, "Close Issue"), + "%s: Expected 'Close Issue' in dialog for single bead", tt.description) } else { // For multiple beads, check title shows correct count - if tt.expectedCount > 1 && !strings.Contains(dialogContent, "Issues") { - t.Errorf("%s: Expected 'Issues' (plural) in dialog for multiple beads", tt.description) + if tt.expectedCount > 1 { + require.True(t, strings.Contains(dialogContent, "Issues"), + "%s: Expected 'Issues' (plural) in dialog for multiple beads", tt.description) } } @@ -126,24 +127,23 @@ func TestMultiSelectionCloseConfirmation(t *testing.T) { selectedCount++ // Only first 5 beads should be shown if shownCount < 5 { - if !strings.Contains(dialogContent, item.ID) { - t.Errorf("%s: Expected bead ID '%s' to appear in dialog (one of first 5)", tt.description, item.ID) - } + require.True(t, strings.Contains(dialogContent, item.ID), + "%s: Expected bead ID '%s' to appear in dialog (one of first 5)", tt.description, item.ID) shownCount++ } } } // If more than 5 selected, check for ellipsis - if selectedCount > 5 && !strings.Contains(dialogContent, "and") && !strings.Contains(dialogContent, "more") { - t.Errorf("%s: Expected '... and X more' for more than 5 selected beads", tt.description) + if selectedCount > 5 { + require.True(t, strings.Contains(dialogContent, "and") || strings.Contains(dialogContent, "more"), + "%s: Expected '... and X more' for more than 5 selected beads", tt.description) } } // Check dialog has confirmation buttons - if !strings.Contains(dialogContent, "[y]") || !strings.Contains(dialogContent, "[n]") { - t.Errorf("%s: Expected confirmation buttons [y] and [n] in dialog", tt.description) - } + require.True(t, strings.Contains(dialogContent, "[y]") && strings.Contains(dialogContent, "[n]"), + "%s: Expected confirmation buttons [y] and [n] in dialog", tt.description) }) } } @@ -224,14 +224,14 @@ func TestUpdateCloseBeadConfirm(t *testing.T) { // Check if view mode changed back to normal if tt.shouldCancel || tt.shouldClose { - if updatedModel.viewMode != ViewNormal { - t.Errorf("%s: Expected viewMode to be ViewNormal after action, got %v", tt.description, updatedModel.viewMode) - } + require.Equal(t, ViewNormal, updatedModel.viewMode, + "%s: Expected viewMode to be ViewNormal after action", tt.description) } // If close was confirmed, a command should be returned - if tt.shouldClose && cmd == nil { - t.Errorf("%s: Expected a command to be returned when confirming close", tt.description) + if tt.shouldClose { + require.NotNil(t, cmd, + "%s: Expected a command to be returned when confirming close", tt.description) } }) } @@ -308,9 +308,8 @@ func TestCloseKeyHandlerWithSelection(t *testing.T) { // Check if dialog was shown as expected dialogShown := m.viewMode == ViewCloseBeadConfirm - if dialogShown != tt.shouldShowDialog { - t.Errorf("%s: Expected dialog shown=%v, got %v", tt.description, tt.shouldShowDialog, dialogShown) - } + require.Equal(t, tt.shouldShowDialog, dialogShown, + "%s: dialog shown state mismatch", tt.description) }) } } @@ -333,9 +332,7 @@ func TestBatchCloseFunction(t *testing.T) { cmd := m.closeBeads(beadIDs) // Verify the command is not nil - if cmd == nil { - t.Error("closeBeads should return a non-nil command") - } + require.NotNil(t, cmd, "closeBeads should return a non-nil command") // In a real scenario, we would verify that the bd command is called with all IDs: // Expected: bd close bead-1 bead-2 bead-3 @@ -404,25 +401,15 @@ func TestCloseConfirmationEdgeCases(t *testing.T) { m := tt.setup() // Test dialog rendering doesn't panic - func() { - defer func() { - if r := recover(); r != nil { - t.Errorf("%s: Panic occurred: %v", tt.name, r) - } - }() + require.NotPanics(t, func() { _ = m.renderCloseBeadConfirmContent() - }() + }, "%s: Panic occurred during dialog rendering", tt.name) // Test update function doesn't panic when confirming - func() { - defer func() { - if r := recover(); r != nil { - t.Errorf("%s: Panic on confirm: %v", tt.name, r) - } - }() + require.NotPanics(t, func() { keyMsg := tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("y")} _, _ = m.updateCloseBeadConfirm(keyMsg) - }() + }, "%s: Panic on confirm", tt.name) }) } } diff --git a/internal/tui/tui_plan_tree_test.go b/internal/tui/tui_plan_tree_test.go index f6813c34..99f09d67 100644 --- a/internal/tui/tui_plan_tree_test.go +++ b/internal/tui/tui_plan_tree_test.go @@ -3,6 +3,8 @@ package tui import ( "context" "testing" + + "github.com/stretchr/testify/require" ) // TestBuildBeadTree_EpicHierarchy tests handling of epic (parent-child) relationships @@ -16,18 +18,14 @@ func TestBuildBeadTree_EpicHierarchy(t *testing.T) { result := buildBeadTree(context.Background(), items, nil) // Verify epic is root and tasks are children - if len(result) != 3 { - t.Fatalf("expected 3 items, got %d", len(result)) - } + require.Len(t, result, 3) - if result[0].ID != "epic-1" || result[0].treeDepth != 0 { - t.Errorf("expected epic-1 at root level, got %s at depth %d", result[0].ID, result[0].treeDepth) - } + require.Equal(t, "epic-1", result[0].ID) + require.Equal(t, 0, result[0].treeDepth, "expected epic-1 at root level") // Both tasks should be at depth 1 - if result[1].treeDepth != 1 || result[2].treeDepth != 1 { - t.Errorf("expected tasks at depth 1, got depths %d and %d", result[1].treeDepth, result[2].treeDepth) - } + require.Equal(t, 1, result[1].treeDepth, "expected task at depth 1") + require.Equal(t, 1, result[2].treeDepth, "expected task at depth 1") } // TestBuildBeadTree_BlocksDependencies tests handling of "blocks" type dependencies @@ -39,17 +37,12 @@ func TestBuildBeadTree_BlocksDependencies(t *testing.T) { result := buildBeadTree(context.Background(), items, nil) - if len(result) != 2 { - t.Fatalf("expected 2 items, got %d", len(result)) - } + require.Len(t, result, 2) // Blocker should be root, blocked should be child - if result[0].ID != "blocker" { - t.Errorf("expected blocker first, got %s", result[0].ID) - } - if result[1].ID != "blocked" || result[1].treeDepth != 1 { - t.Errorf("expected blocked at depth 1, got %s at depth %d", result[1].ID, result[1].treeDepth) - } + require.Equal(t, "blocker", result[0].ID) + require.Equal(t, "blocked", result[1].ID) + require.Equal(t, 1, result[1].treeDepth, "expected blocked at depth 1") } // TestBuildBeadTree_ClosedParentVisibility tests filtering of closed parents @@ -62,9 +55,7 @@ func TestBuildBeadTree_ClosedParentVisibility(t *testing.T) { result := buildBeadTree(context.Background(), items, nil) // Both parent and child should be visible since parent has visible child - if len(result) != 2 { - t.Errorf("expected both parent and child visible, got %d items", len(result)) - } + require.Len(t, result, 2, "expected both parent and child visible") } // TestBuildBeadTree_ClosedParentNoVisibleChildren tests filtering out closed parents without visible children @@ -76,9 +67,7 @@ func TestBuildBeadTree_ClosedParentNoVisibleChildren(t *testing.T) { result := buildBeadTree(context.Background(), items, nil) // Parent should be filtered out since it has no visible children - if len(result) != 0 { - t.Errorf("expected closed parent without children to be filtered out, got %d items", len(result)) - } + require.Empty(t, result, "expected closed parent without children to be filtered out") } // TestBuildBeadTree_MultiLevelNesting tests deep hierarchy @@ -92,16 +81,12 @@ func TestBuildBeadTree_MultiLevelNesting(t *testing.T) { result := buildBeadTree(context.Background(), items, nil) - if len(result) != 4 { - t.Fatalf("expected 4 items, got %d", len(result)) - } + require.Len(t, result, 4) // Verify each level has correct depth expectedDepths := []int{0, 1, 2, 3} for i, item := range result { - if item.treeDepth != expectedDepths[i] { - t.Errorf("item %s expected depth %d, got %d", item.ID, expectedDepths[i], item.treeDepth) - } + require.Equal(t, expectedDepths[i], item.treeDepth, "item %s has wrong depth", item.ID) } } @@ -116,9 +101,7 @@ func TestBuildBeadTree_MultipleRoots(t *testing.T) { result := buildBeadTree(context.Background(), items, nil) - if len(result) != 4 { - t.Fatalf("expected 4 items, got %d", len(result)) - } + require.Len(t, result, 4) // Count roots (depth 0) rootCount := 0 @@ -128,9 +111,7 @@ func TestBuildBeadTree_MultipleRoots(t *testing.T) { } } - if rootCount != 2 { - t.Errorf("expected 2 roots, got %d", rootCount) - } + require.Equal(t, 2, rootCount, "expected 2 roots") } // TestBuildBeadTree_MixedTypes tests handling of different dependency types together @@ -144,9 +125,7 @@ func TestBuildBeadTree_MixedTypes(t *testing.T) { result := buildBeadTree(context.Background(), items, nil) - if len(result) != 4 { - t.Fatalf("expected 4 items, got %d", len(result)) - } + require.Len(t, result, 4) // Verify mixed types are handled correctly rootTypes := make(map[string]bool) @@ -156,9 +135,7 @@ func TestBuildBeadTree_MixedTypes(t *testing.T) { } } - if len(rootTypes) < 2 { - t.Errorf("expected multiple types at root level") - } + require.GreaterOrEqual(t, len(rootTypes), 2, "expected multiple types at root level") } // TestBuildBeadTree_CircularDependencies tests handling of circular dependency detection @@ -173,9 +150,7 @@ func TestBuildBeadTree_CircularDependencies(t *testing.T) { result := buildBeadTree(context.Background(), items, nil) // Should still produce all 3 items - if len(result) != 3 { - t.Fatalf("expected 3 items despite circular dependency, got %d", len(result)) - } + require.Len(t, result, 3, "expected 3 items despite circular dependency") } // TestBuildBeadTree_EmptyInput tests handling of empty input @@ -183,9 +158,7 @@ func TestBuildBeadTree_EmptyInput(t *testing.T) { items := []beadItem{} result := buildBeadTree(context.Background(), items, nil) - if len(result) != 0 { - t.Errorf("expected empty result for empty input, got %d items", len(result)) - } + require.Empty(t, result, "expected empty result for empty input") } // TestBuildBeadTree_WithNilClient tests that the function works with nil client @@ -197,9 +170,7 @@ func TestBuildBeadTree_WithNilClient(t *testing.T) { result := buildBeadTree(context.Background(), items, nil) - if len(result) != 1 { - t.Fatalf("expected 1 item, got %d", len(result)) - } + require.Len(t, result, 1) } // TestBuildBeadTree_ParentChildRelationship tests that parent-child relationships are preserved @@ -213,7 +184,5 @@ func TestBuildBeadTree_ParentChildRelationship(t *testing.T) { result := buildBeadTree(context.Background(), items, nil) // Both parent and child should be visible - if len(result) != 2 { - t.Errorf("expected parent to be visible with open child, got %d items", len(result)) - } + require.Len(t, result, 2, "expected parent to be visible with open child") } diff --git a/internal/work/import_pr_test.go b/internal/work/import_pr_test.go index ac9d4187..76426bcb 100644 --- a/internal/work/import_pr_test.go +++ b/internal/work/import_pr_test.go @@ -3,222 +3,37 @@ package work import ( "context" "errors" + "strings" "testing" "time" + "github.com/newhook/co/internal/git" "github.com/newhook/co/internal/github" "github.com/newhook/co/internal/worktree" + "github.com/stretchr/testify/require" ) -// MockClientInterface implements github.ClientInterface for testing. -type MockClientInterface struct { - GetPRStatusFunc func(ctx context.Context, prURL string) (*github.PRStatus, error) - GetPRMetadataFunc func(ctx context.Context, prURLOrNumber string, repo string) (*github.PRMetadata, error) - PostPRCommentFunc func(ctx context.Context, prURL string, body string) error - PostReplyToCommentFunc func(ctx context.Context, prURL string, commentID int, body string) error - PostReviewReplyFunc func(ctx context.Context, prURL string, reviewCommentID int, body string) error - ResolveReviewThreadFunc func(ctx context.Context, prURL string, commentID int) error - GetJobLogsFunc func(ctx context.Context, repo string, jobID int64) (string, error) -} - -func (m *MockClientInterface) GetPRStatus(ctx context.Context, prURL string) (*github.PRStatus, error) { - if m.GetPRStatusFunc != nil { - return m.GetPRStatusFunc(ctx, prURL) - } - return nil, errors.New("GetPRStatus not implemented") -} - -func (m *MockClientInterface) GetPRMetadata(ctx context.Context, prURLOrNumber string, repo string) (*github.PRMetadata, error) { - if m.GetPRMetadataFunc != nil { - return m.GetPRMetadataFunc(ctx, prURLOrNumber, repo) - } - return nil, errors.New("GetPRMetadata not implemented") -} - -func (m *MockClientInterface) PostPRComment(ctx context.Context, prURL string, body string) error { - if m.PostPRCommentFunc != nil { - return m.PostPRCommentFunc(ctx, prURL, body) - } - return errors.New("PostPRComment not implemented") -} - -func (m *MockClientInterface) PostReplyToComment(ctx context.Context, prURL string, commentID int, body string) error { - if m.PostReplyToCommentFunc != nil { - return m.PostReplyToCommentFunc(ctx, prURL, commentID, body) - } - return errors.New("PostReplyToComment not implemented") -} - -func (m *MockClientInterface) PostReviewReply(ctx context.Context, prURL string, reviewCommentID int, body string) error { - if m.PostReviewReplyFunc != nil { - return m.PostReviewReplyFunc(ctx, prURL, reviewCommentID, body) - } - return errors.New("PostReviewReply not implemented") -} - -func (m *MockClientInterface) ResolveReviewThread(ctx context.Context, prURL string, commentID int) error { - if m.ResolveReviewThreadFunc != nil { - return m.ResolveReviewThreadFunc(ctx, prURL, commentID) - } - return errors.New("ResolveReviewThread not implemented") -} - -func (m *MockClientInterface) GetJobLogs(ctx context.Context, repo string, jobID int64) (string, error) { - if m.GetJobLogsFunc != nil { - return m.GetJobLogsFunc(ctx, repo, jobID) - } - return "", errors.New("GetJobLogs not implemented") -} - -// MockGitOperations implements git.Operations for testing. -type MockGitOperations struct { - PushSetUpstreamFunc func(ctx context.Context, branch, dir string) error - PullFunc func(ctx context.Context, dir string) error - CloneFunc func(ctx context.Context, source, dest string) error - FetchBranchFunc func(ctx context.Context, repoPath, branch string) error - FetchPRRefFunc func(ctx context.Context, repoPath string, prNumber int, localBranch string) error - BranchExistsFunc func(ctx context.Context, repoPath, branchName string) bool - ValidateExistingBranchFunc func(ctx context.Context, repoPath, branchName string) (bool, bool, error) - ListBranchesFunc func(ctx context.Context, repoPath string) ([]string, error) -} - -func (m *MockGitOperations) PushSetUpstream(ctx context.Context, branch, dir string) error { - if m.PushSetUpstreamFunc != nil { - return m.PushSetUpstreamFunc(ctx, branch, dir) - } - return nil -} - -func (m *MockGitOperations) Pull(ctx context.Context, dir string) error { - if m.PullFunc != nil { - return m.PullFunc(ctx, dir) - } - return nil -} - -func (m *MockGitOperations) Clone(ctx context.Context, source, dest string) error { - if m.CloneFunc != nil { - return m.CloneFunc(ctx, source, dest) - } - return nil -} - -func (m *MockGitOperations) FetchBranch(ctx context.Context, repoPath, branch string) error { - if m.FetchBranchFunc != nil { - return m.FetchBranchFunc(ctx, repoPath, branch) - } - return nil -} - -func (m *MockGitOperations) FetchPRRef(ctx context.Context, repoPath string, prNumber int, localBranch string) error { - if m.FetchPRRefFunc != nil { - return m.FetchPRRefFunc(ctx, repoPath, prNumber, localBranch) - } - return nil -} - -func (m *MockGitOperations) BranchExists(ctx context.Context, repoPath, branchName string) bool { - if m.BranchExistsFunc != nil { - return m.BranchExistsFunc(ctx, repoPath, branchName) - } - return false -} - -func (m *MockGitOperations) ValidateExistingBranch(ctx context.Context, repoPath, branchName string) (bool, bool, error) { - if m.ValidateExistingBranchFunc != nil { - return m.ValidateExistingBranchFunc(ctx, repoPath, branchName) - } - return false, false, nil -} - -func (m *MockGitOperations) ListBranches(ctx context.Context, repoPath string) ([]string, error) { - if m.ListBranchesFunc != nil { - return m.ListBranchesFunc(ctx, repoPath) - } - return nil, nil -} - -// MockWorktreeOperations implements worktree.Operations for testing. -type MockWorktreeOperations struct { - CreateFunc func(ctx context.Context, repoPath, worktreePath, branch, baseBranch string) error - CreateFromExistingFunc func(ctx context.Context, repoPath, worktreePath, branch string) error - RemoveForceFunc func(ctx context.Context, repoPath, worktreePath string) error - ListFunc func(ctx context.Context, repoPath string) ([]worktree.Worktree, error) - ExistsPathFunc func(worktreePath string) bool -} - -func (m *MockWorktreeOperations) Create(ctx context.Context, repoPath, worktreePath, branch, baseBranch string) error { - if m.CreateFunc != nil { - return m.CreateFunc(ctx, repoPath, worktreePath, branch, baseBranch) - } - return nil -} - -func (m *MockWorktreeOperations) CreateFromExisting(ctx context.Context, repoPath, worktreePath, branch string) error { - if m.CreateFromExistingFunc != nil { - return m.CreateFromExistingFunc(ctx, repoPath, worktreePath, branch) - } - return nil -} - -func (m *MockWorktreeOperations) RemoveForce(ctx context.Context, repoPath, worktreePath string) error { - if m.RemoveForceFunc != nil { - return m.RemoveForceFunc(ctx, repoPath, worktreePath) - } - return nil -} - -func (m *MockWorktreeOperations) List(ctx context.Context, repoPath string) ([]worktree.Worktree, error) { - if m.ListFunc != nil { - return m.ListFunc(ctx, repoPath) - } - return nil, nil -} - -func (m *MockWorktreeOperations) ExistsPath(worktreePath string) bool { - if m.ExistsPathFunc != nil { - return m.ExistsPathFunc(worktreePath) - } - return false -} - func TestNewPRImporter(t *testing.T) { - client := &MockClientInterface{} + client := &github.GitHubClientMock{} importer := NewPRImporter(client) - if importer == nil { - t.Fatal("NewPRImporter returned nil") - } - if importer.client != client { - t.Error("client not set correctly") - } - if importer.gitOps == nil { - t.Error("gitOps should be initialized") - } - if importer.worktreeOps == nil { - t.Error("worktreeOps should be initialized") - } + require.NotNil(t, importer, "NewPRImporter returned nil") + require.Equal(t, client, importer.client, "client not set correctly") + require.NotNil(t, importer.gitOps, "gitOps should be initialized") + require.NotNil(t, importer.worktreeOps, "worktreeOps should be initialized") } func TestNewPRImporterWithOps(t *testing.T) { - client := &MockClientInterface{} - gitOps := &MockGitOperations{} - worktreeOps := &MockWorktreeOperations{} + client := &github.GitHubClientMock{} + gitOps := &git.GitOperationsMock{} + worktreeOps := &worktree.WorktreeOperationsMock{} importer := NewPRImporterWithOps(client, gitOps, worktreeOps) - if importer == nil { - t.Fatal("NewPRImporterWithOps returned nil") - } - if importer.client != client { - t.Error("client not set correctly") - } - if importer.gitOps == nil { - t.Error("gitOps should be set") - } - if importer.worktreeOps == nil { - t.Error("worktreeOps should be set") - } + require.NotNil(t, importer, "NewPRImporterWithOps returned nil") + require.Equal(t, client, importer.client, "client not set correctly") + require.NotNil(t, importer.gitOps, "gitOps should be set") + require.NotNil(t, importer.worktreeOps, "worktreeOps should be set") } func TestSetupWorktreeFromPR_Success(t *testing.T) { @@ -232,36 +47,28 @@ func TestSetupWorktreeFromPR_Success(t *testing.T) { BaseRefName: "main", } - client := &MockClientInterface{ + client := &github.GitHubClientMock{ GetPRMetadataFunc: func(ctx context.Context, prURLOrNumber string, repo string) (*github.PRMetadata, error) { return metadata, nil }, } fetchPRRefCalled := false - gitOps := &MockGitOperations{ + gitOps := &git.GitOperationsMock{ FetchPRRefFunc: func(ctx context.Context, repoPath string, prNumber int, localBranch string) error { fetchPRRefCalled = true - if prNumber != 123 { - t.Errorf("expected PR number 123, got %d", prNumber) - } - if localBranch != "feature-branch" { - t.Errorf("expected branch 'feature-branch', got %s", localBranch) - } + require.Equal(t, 123, prNumber) + require.Equal(t, "feature-branch", localBranch) return nil }, } createFromExistingCalled := false - worktreeOps := &MockWorktreeOperations{ + worktreeOps := &worktree.WorktreeOperationsMock{ CreateFromExistingFunc: func(ctx context.Context, repoPath, worktreePath, branch string) error { createFromExistingCalled = true - if branch != "feature-branch" { - t.Errorf("expected branch 'feature-branch', got %s", branch) - } - if worktreePath != "/work/dir/tree" { - t.Errorf("expected worktreePath '/work/dir/tree', got %s", worktreePath) - } + require.Equal(t, "feature-branch", branch) + require.Equal(t, "/work/dir/tree", worktreePath) return nil }, } @@ -270,25 +77,11 @@ func TestSetupWorktreeFromPR_Success(t *testing.T) { resultMetadata, worktreePath, err := importer.SetupWorktreeFromPR(ctx, "/repo/path", "https://github.com/owner/repo/pull/123", "", "/work/dir", "") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if !fetchPRRefCalled { - t.Error("FetchPRRef was not called") - } - - if !createFromExistingCalled { - t.Error("CreateFromExisting was not called") - } - - if resultMetadata.Number != 123 { - t.Errorf("expected PR number 123, got %d", resultMetadata.Number) - } - - if worktreePath != "/work/dir/tree" { - t.Errorf("expected worktreePath '/work/dir/tree', got %s", worktreePath) - } + require.NoError(t, err) + require.True(t, fetchPRRefCalled, "FetchPRRef was not called") + require.True(t, createFromExistingCalled, "CreateFromExisting was not called") + require.Equal(t, 123, resultMetadata.Number) + require.Equal(t, "/work/dir/tree", worktreePath) } func TestSetupWorktreeFromPR_CustomBranchName(t *testing.T) { @@ -302,26 +95,22 @@ func TestSetupWorktreeFromPR_CustomBranchName(t *testing.T) { BaseRefName: "main", } - client := &MockClientInterface{ + client := &github.GitHubClientMock{ GetPRMetadataFunc: func(ctx context.Context, prURLOrNumber string, repo string) (*github.PRMetadata, error) { return metadata, nil }, } - gitOps := &MockGitOperations{ + gitOps := &git.GitOperationsMock{ FetchPRRefFunc: func(ctx context.Context, repoPath string, prNumber int, localBranch string) error { - if localBranch != "custom-branch" { - t.Errorf("expected branch 'custom-branch', got %s", localBranch) - } + require.Equal(t, "custom-branch", localBranch) return nil }, } - worktreeOps := &MockWorktreeOperations{ + worktreeOps := &worktree.WorktreeOperationsMock{ CreateFromExistingFunc: func(ctx context.Context, repoPath, worktreePath, branch string) error { - if branch != "custom-branch" { - t.Errorf("expected branch 'custom-branch', got %s", branch) - } + require.Equal(t, "custom-branch", branch) return nil }, } @@ -330,34 +119,27 @@ func TestSetupWorktreeFromPR_CustomBranchName(t *testing.T) { _, _, err := importer.SetupWorktreeFromPR(ctx, "/repo/path", "https://github.com/owner/repo/pull/123", "", "/work/dir", "custom-branch") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + require.NoError(t, err) } func TestSetupWorktreeFromPR_MetadataError(t *testing.T) { ctx := context.Background() - client := &MockClientInterface{ + client := &github.GitHubClientMock{ GetPRMetadataFunc: func(ctx context.Context, prURLOrNumber string, repo string) (*github.PRMetadata, error) { return nil, errors.New("API error") }, } - gitOps := &MockGitOperations{} - worktreeOps := &MockWorktreeOperations{} + gitOps := &git.GitOperationsMock{} + worktreeOps := &worktree.WorktreeOperationsMock{} importer := NewPRImporterWithOps(client, gitOps, worktreeOps) _, _, err := importer.SetupWorktreeFromPR(ctx, "/repo/path", "https://github.com/owner/repo/pull/123", "", "/work/dir", "") - if err == nil { - t.Fatal("expected error, got nil") - } - - if !errors.Is(err, errors.New("API error")) && err.Error() != "failed to get PR metadata: API error" { - t.Errorf("unexpected error message: %v", err) - } + require.Error(t, err) + require.Equal(t, "failed to get PR metadata: API error", err.Error()) } func TestSetupWorktreeFromPR_FetchPRRefError(t *testing.T) { @@ -368,32 +150,27 @@ func TestSetupWorktreeFromPR_FetchPRRefError(t *testing.T) { HeadRefName: "feature-branch", } - client := &MockClientInterface{ + client := &github.GitHubClientMock{ GetPRMetadataFunc: func(ctx context.Context, prURLOrNumber string, repo string) (*github.PRMetadata, error) { return metadata, nil }, } - gitOps := &MockGitOperations{ + gitOps := &git.GitOperationsMock{ FetchPRRefFunc: func(ctx context.Context, repoPath string, prNumber int, localBranch string) error { return errors.New("fetch failed") }, } - worktreeOps := &MockWorktreeOperations{} + worktreeOps := &worktree.WorktreeOperationsMock{} importer := NewPRImporterWithOps(client, gitOps, worktreeOps) resultMetadata, _, err := importer.SetupWorktreeFromPR(ctx, "/repo/path", "https://github.com/owner/repo/pull/123", "", "/work/dir", "") - if err == nil { - t.Fatal("expected error, got nil") - } - + require.Error(t, err) // Metadata should still be returned on fetch failure - if resultMetadata == nil { - t.Error("metadata should be returned even on fetch failure") - } + require.NotNil(t, resultMetadata, "metadata should be returned even on fetch failure") } func TestSetupWorktreeFromPR_WorktreeCreateError(t *testing.T) { @@ -404,19 +181,19 @@ func TestSetupWorktreeFromPR_WorktreeCreateError(t *testing.T) { HeadRefName: "feature-branch", } - client := &MockClientInterface{ + client := &github.GitHubClientMock{ GetPRMetadataFunc: func(ctx context.Context, prURLOrNumber string, repo string) (*github.PRMetadata, error) { return metadata, nil }, } - gitOps := &MockGitOperations{ + gitOps := &git.GitOperationsMock{ FetchPRRefFunc: func(ctx context.Context, repoPath string, prNumber int, localBranch string) error { return nil }, } - worktreeOps := &MockWorktreeOperations{ + worktreeOps := &worktree.WorktreeOperationsMock{ CreateFromExistingFunc: func(ctx context.Context, repoPath, worktreePath, branch string) error { return errors.New("worktree create failed") }, @@ -426,14 +203,9 @@ func TestSetupWorktreeFromPR_WorktreeCreateError(t *testing.T) { resultMetadata, _, err := importer.SetupWorktreeFromPR(ctx, "/repo/path", "https://github.com/owner/repo/pull/123", "", "/work/dir", "") - if err == nil { - t.Fatal("expected error, got nil") - } - + require.Error(t, err) // Metadata should still be returned on worktree create failure - if resultMetadata == nil { - t.Error("metadata should be returned even on worktree create failure") - } + require.NotNil(t, resultMetadata, "metadata should be returned even on worktree create failure") } func TestFetchPRMetadata(t *testing.T) { @@ -444,14 +216,10 @@ func TestFetchPRMetadata(t *testing.T) { Title: "Test PR", } - client := &MockClientInterface{ + client := &github.GitHubClientMock{ GetPRMetadataFunc: func(ctx context.Context, prURLOrNumber string, repo string) (*github.PRMetadata, error) { - if prURLOrNumber != "456" { - t.Errorf("expected prURLOrNumber '456', got %s", prURLOrNumber) - } - if repo != "owner/repo" { - t.Errorf("expected repo 'owner/repo', got %s", repo) - } + require.Equal(t, "456", prURLOrNumber) + require.Equal(t, "owner/repo", repo) return expectedMetadata, nil }, } @@ -460,13 +228,8 @@ func TestFetchPRMetadata(t *testing.T) { metadata, err := importer.FetchPRMetadata(ctx, "456", "owner/repo") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if metadata.Number != 456 { - t.Errorf("expected PR number 456, got %d", metadata.Number) - } + require.NoError(t, err) + require.Equal(t, 456, metadata.Number) } func TestMapPRToBeadCreate(t *testing.T) { @@ -485,42 +248,23 @@ func TestMapPRToBeadCreate(t *testing.T) { opts := mapPRToBeadCreate(pr) - if opts.title != "Add new feature" { - t.Errorf("expected title 'Add new feature', got %s", opts.title) - } - - if opts.description != "This PR adds a new feature" { - t.Errorf("expected description 'This PR adds a new feature', got %s", opts.description) - } + require.Equal(t, "Add new feature", opts.title) + require.Equal(t, "This PR adds a new feature", opts.description) // Should detect feature type from labels - if opts.issueType != "feature" { - t.Errorf("expected type 'feature', got %s", opts.issueType) - } + require.Equal(t, "feature", opts.issueType) // Should have default P2 priority - if opts.priority != "P2" { - t.Errorf("expected priority 'P2', got %s", opts.priority) - } + require.Equal(t, "P2", opts.priority) // Labels should be passed through - if len(opts.labels) != 2 { - t.Errorf("expected 2 labels, got %d", len(opts.labels)) - } + require.Len(t, opts.labels, 2) // Metadata should contain PR info - if opts.metadata["pr_url"] != "https://github.com/owner/repo/pull/123" { - t.Error("pr_url metadata not set correctly") - } - if opts.metadata["pr_number"] != "123" { - t.Error("pr_number metadata not set correctly") - } - if opts.metadata["pr_branch"] != "feature-branch" { - t.Error("pr_branch metadata not set correctly") - } - if opts.metadata["pr_author"] != "testuser" { - t.Error("pr_author metadata not set correctly") - } + require.Equal(t, "https://github.com/owner/repo/pull/123", opts.metadata["pr_url"]) + require.Equal(t, "123", opts.metadata["pr_number"]) + require.Equal(t, "feature-branch", opts.metadata["pr_branch"]) + require.Equal(t, "testuser", opts.metadata["pr_author"]) } func TestMapPRType(t *testing.T) { @@ -606,9 +350,7 @@ func TestMapPRType(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := mapPRType(tt.pr) - if result != tt.expected { - t.Errorf("expected %s, got %s", tt.expected, result) - } + require.Equal(t, tt.expected, result) }) } } @@ -708,9 +450,7 @@ func TestMapPRPriority(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := mapPRPriority(tt.pr) - if result != tt.expected { - t.Errorf("expected %s, got %s", tt.expected, result) - } + require.Equal(t, tt.expected, result) }) } } @@ -776,9 +516,7 @@ func TestMapPRStatus(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := mapPRStatus(tt.pr) - if result != tt.expected { - t.Errorf("expected %s, got %s", tt.expected, result) - } + require.Equal(t, tt.expected, result) }) } } @@ -881,9 +619,8 @@ func TestFormatBeadDescription(t *testing.T) { result := formatBeadDescription(tt.pr) for _, expected := range tt.contains { - if !containsString(result, expected) { - t.Errorf("expected description to contain %q, got:\n%s", expected, result) - } + require.True(t, strings.Contains(result, expected), + "expected description to contain %q, got:\n%s", expected, result) } }) } @@ -910,9 +647,7 @@ func TestParsePriority(t *testing.T) { for _, tt := range tests { t.Run(tt.input, func(t *testing.T) { result := parsePriority(tt.input) - if result != tt.expected { - t.Errorf("parsePriority(%q) = %d, expected %d", tt.input, result, tt.expected) - } + require.Equal(t, tt.expected, result) }) } } @@ -926,21 +661,11 @@ func TestCreateBeadOptions(t *testing.T) { OverridePriority: "P1", } - if opts.BeadsDir != "/path/to/beads" { - t.Error("BeadsDir not set correctly") - } - if !opts.SkipIfExists { - t.Error("SkipIfExists not set correctly") - } - if opts.OverrideTitle != "Custom Title" { - t.Error("OverrideTitle not set correctly") - } - if opts.OverrideType != "bug" { - t.Error("OverrideType not set correctly") - } - if opts.OverridePriority != "P1" { - t.Error("OverridePriority not set correctly") - } + require.Equal(t, "/path/to/beads", opts.BeadsDir) + require.True(t, opts.SkipIfExists) + require.Equal(t, "Custom Title", opts.OverrideTitle) + require.Equal(t, "bug", opts.OverrideType) + require.Equal(t, "P1", opts.OverridePriority) } func TestCreateBeadResult(t *testing.T) { @@ -950,12 +675,8 @@ func TestCreateBeadResult(t *testing.T) { SkipReason: "", } - if result.BeadID != "bead-123" { - t.Error("BeadID not set correctly") - } - if !result.Created { - t.Error("Created not set correctly") - } + require.Equal(t, "bead-123", result.BeadID) + require.True(t, result.Created) // Test skip result skipResult := &CreateBeadResult{ @@ -964,24 +685,7 @@ func TestCreateBeadResult(t *testing.T) { SkipReason: "bead already exists for this PR", } - if skipResult.Created { - t.Error("Created should be false for skipped bead") - } - if skipResult.SkipReason == "" { - t.Error("SkipReason should be set for skipped bead") - } -} - -// containsString checks if s contains substr. -func containsString(s, substr string) bool { - return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsSubstring(s, substr)) + require.False(t, skipResult.Created, "Created should be false for skipped bead") + require.NotEmpty(t, skipResult.SkipReason, "SkipReason should be set for skipped bead") } -func containsSubstring(s, substr string) bool { - for i := 0; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return true - } - } - return false -} diff --git a/internal/worktree/worktree.go b/internal/worktree/worktree.go index f34459ad..d9914c1f 100644 --- a/internal/worktree/worktree.go +++ b/internal/worktree/worktree.go @@ -1,5 +1,7 @@ package worktree +//go:generate moq -stub -out worktree_mock.go . Operations:WorktreeOperationsMock + import ( "bufio" "context" diff --git a/internal/worktree/worktree_mock.go b/internal/worktree/worktree_mock.go new file mode 100644 index 00000000..b39412bc --- /dev/null +++ b/internal/worktree/worktree_mock.go @@ -0,0 +1,327 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package worktree + +import ( + "context" + "sync" +) + +// Ensure, that WorktreeOperationsMock does implement Operations. +// If this is not the case, regenerate this file with moq. +var _ Operations = &WorktreeOperationsMock{} + +// WorktreeOperationsMock is a mock implementation of Operations. +// +// func TestSomethingThatUsesOperations(t *testing.T) { +// +// // make and configure a mocked Operations +// mockedOperations := &WorktreeOperationsMock{ +// CreateFunc: func(ctx context.Context, repoPath string, worktreePath string, branch string, baseBranch string) error { +// panic("mock out the Create method") +// }, +// CreateFromExistingFunc: func(ctx context.Context, repoPath string, worktreePath string, branch string) error { +// panic("mock out the CreateFromExisting method") +// }, +// ExistsPathFunc: func(worktreePath string) bool { +// panic("mock out the ExistsPath method") +// }, +// ListFunc: func(ctx context.Context, repoPath string) ([]Worktree, error) { +// panic("mock out the List method") +// }, +// RemoveForceFunc: func(ctx context.Context, repoPath string, worktreePath string) error { +// panic("mock out the RemoveForce method") +// }, +// } +// +// // use mockedOperations in code that requires Operations +// // and then make assertions. +// +// } +type WorktreeOperationsMock struct { + // CreateFunc mocks the Create method. + CreateFunc func(ctx context.Context, repoPath string, worktreePath string, branch string, baseBranch string) error + + // CreateFromExistingFunc mocks the CreateFromExisting method. + CreateFromExistingFunc func(ctx context.Context, repoPath string, worktreePath string, branch string) error + + // ExistsPathFunc mocks the ExistsPath method. + ExistsPathFunc func(worktreePath string) bool + + // ListFunc mocks the List method. + ListFunc func(ctx context.Context, repoPath string) ([]Worktree, error) + + // RemoveForceFunc mocks the RemoveForce method. + RemoveForceFunc func(ctx context.Context, repoPath string, worktreePath string) error + + // calls tracks calls to the methods. + calls struct { + // Create holds details about calls to the Create method. + Create []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // RepoPath is the repoPath argument value. + RepoPath string + // WorktreePath is the worktreePath argument value. + WorktreePath string + // Branch is the branch argument value. + Branch string + // BaseBranch is the baseBranch argument value. + BaseBranch string + } + // CreateFromExisting holds details about calls to the CreateFromExisting method. + CreateFromExisting []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // RepoPath is the repoPath argument value. + RepoPath string + // WorktreePath is the worktreePath argument value. + WorktreePath string + // Branch is the branch argument value. + Branch string + } + // ExistsPath holds details about calls to the ExistsPath method. + ExistsPath []struct { + // WorktreePath is the worktreePath argument value. + WorktreePath string + } + // List holds details about calls to the List method. + List []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // RepoPath is the repoPath argument value. + RepoPath string + } + // RemoveForce holds details about calls to the RemoveForce method. + RemoveForce []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // RepoPath is the repoPath argument value. + RepoPath string + // WorktreePath is the worktreePath argument value. + WorktreePath string + } + } + lockCreate sync.RWMutex + lockCreateFromExisting sync.RWMutex + lockExistsPath sync.RWMutex + lockList sync.RWMutex + lockRemoveForce sync.RWMutex +} + +// Create calls CreateFunc. +func (mock *WorktreeOperationsMock) Create(ctx context.Context, repoPath string, worktreePath string, branch string, baseBranch string) error { + callInfo := struct { + Ctx context.Context + RepoPath string + WorktreePath string + Branch string + BaseBranch string + }{ + Ctx: ctx, + RepoPath: repoPath, + WorktreePath: worktreePath, + Branch: branch, + BaseBranch: baseBranch, + } + mock.lockCreate.Lock() + mock.calls.Create = append(mock.calls.Create, callInfo) + mock.lockCreate.Unlock() + if mock.CreateFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.CreateFunc(ctx, repoPath, worktreePath, branch, baseBranch) +} + +// CreateCalls gets all the calls that were made to Create. +// Check the length with: +// +// len(mockedOperations.CreateCalls()) +func (mock *WorktreeOperationsMock) CreateCalls() []struct { + Ctx context.Context + RepoPath string + WorktreePath string + Branch string + BaseBranch string +} { + var calls []struct { + Ctx context.Context + RepoPath string + WorktreePath string + Branch string + BaseBranch string + } + mock.lockCreate.RLock() + calls = mock.calls.Create + mock.lockCreate.RUnlock() + return calls +} + +// CreateFromExisting calls CreateFromExistingFunc. +func (mock *WorktreeOperationsMock) CreateFromExisting(ctx context.Context, repoPath string, worktreePath string, branch string) error { + callInfo := struct { + Ctx context.Context + RepoPath string + WorktreePath string + Branch string + }{ + Ctx: ctx, + RepoPath: repoPath, + WorktreePath: worktreePath, + Branch: branch, + } + mock.lockCreateFromExisting.Lock() + mock.calls.CreateFromExisting = append(mock.calls.CreateFromExisting, callInfo) + mock.lockCreateFromExisting.Unlock() + if mock.CreateFromExistingFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.CreateFromExistingFunc(ctx, repoPath, worktreePath, branch) +} + +// CreateFromExistingCalls gets all the calls that were made to CreateFromExisting. +// Check the length with: +// +// len(mockedOperations.CreateFromExistingCalls()) +func (mock *WorktreeOperationsMock) CreateFromExistingCalls() []struct { + Ctx context.Context + RepoPath string + WorktreePath string + Branch string +} { + var calls []struct { + Ctx context.Context + RepoPath string + WorktreePath string + Branch string + } + mock.lockCreateFromExisting.RLock() + calls = mock.calls.CreateFromExisting + mock.lockCreateFromExisting.RUnlock() + return calls +} + +// ExistsPath calls ExistsPathFunc. +func (mock *WorktreeOperationsMock) ExistsPath(worktreePath string) bool { + callInfo := struct { + WorktreePath string + }{ + WorktreePath: worktreePath, + } + mock.lockExistsPath.Lock() + mock.calls.ExistsPath = append(mock.calls.ExistsPath, callInfo) + mock.lockExistsPath.Unlock() + if mock.ExistsPathFunc == nil { + var ( + bOut bool + ) + return bOut + } + return mock.ExistsPathFunc(worktreePath) +} + +// ExistsPathCalls gets all the calls that were made to ExistsPath. +// Check the length with: +// +// len(mockedOperations.ExistsPathCalls()) +func (mock *WorktreeOperationsMock) ExistsPathCalls() []struct { + WorktreePath string +} { + var calls []struct { + WorktreePath string + } + mock.lockExistsPath.RLock() + calls = mock.calls.ExistsPath + mock.lockExistsPath.RUnlock() + return calls +} + +// List calls ListFunc. +func (mock *WorktreeOperationsMock) List(ctx context.Context, repoPath string) ([]Worktree, error) { + callInfo := struct { + Ctx context.Context + RepoPath string + }{ + Ctx: ctx, + RepoPath: repoPath, + } + mock.lockList.Lock() + mock.calls.List = append(mock.calls.List, callInfo) + mock.lockList.Unlock() + if mock.ListFunc == nil { + var ( + worktreesOut []Worktree + errOut error + ) + return worktreesOut, errOut + } + return mock.ListFunc(ctx, repoPath) +} + +// ListCalls gets all the calls that were made to List. +// Check the length with: +// +// len(mockedOperations.ListCalls()) +func (mock *WorktreeOperationsMock) ListCalls() []struct { + Ctx context.Context + RepoPath string +} { + var calls []struct { + Ctx context.Context + RepoPath string + } + mock.lockList.RLock() + calls = mock.calls.List + mock.lockList.RUnlock() + return calls +} + +// RemoveForce calls RemoveForceFunc. +func (mock *WorktreeOperationsMock) RemoveForce(ctx context.Context, repoPath string, worktreePath string) error { + callInfo := struct { + Ctx context.Context + RepoPath string + WorktreePath string + }{ + Ctx: ctx, + RepoPath: repoPath, + WorktreePath: worktreePath, + } + mock.lockRemoveForce.Lock() + mock.calls.RemoveForce = append(mock.calls.RemoveForce, callInfo) + mock.lockRemoveForce.Unlock() + if mock.RemoveForceFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.RemoveForceFunc(ctx, repoPath, worktreePath) +} + +// RemoveForceCalls gets all the calls that were made to RemoveForce. +// Check the length with: +// +// len(mockedOperations.RemoveForceCalls()) +func (mock *WorktreeOperationsMock) RemoveForceCalls() []struct { + Ctx context.Context + RepoPath string + WorktreePath string +} { + var calls []struct { + Ctx context.Context + RepoPath string + WorktreePath string + } + mock.lockRemoveForce.RLock() + calls = mock.calls.RemoveForce + mock.lockRemoveForce.RUnlock() + return calls +} diff --git a/internal/worktree/worktree_test.go b/internal/worktree/worktree_test.go new file mode 100644 index 00000000..8a6dd2bf --- /dev/null +++ b/internal/worktree/worktree_test.go @@ -0,0 +1,116 @@ +package worktree + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewOperations(t *testing.T) { + ops := NewOperations() + require.NotNil(t, ops, "NewOperations returned nil") + + // Verify it returns a CLIOperations + _, ok := ops.(*CLIOperations) + require.True(t, ok, "NewOperations should return *CLIOperations") +} + +func TestCLIOperationsImplementsInterface(t *testing.T) { + // Compile-time check that CLIOperations implements Operations + var _ Operations = (*CLIOperations)(nil) +} + +func TestParseWorktreeList(t *testing.T) { + tests := []struct { + name string + input string + expected []Worktree + }{ + { + name: "empty output", + input: "", + expected: nil, + }, + { + name: "single worktree", + input: `worktree /path/to/main +HEAD abc123def456 +branch refs/heads/main +`, + expected: []Worktree{ + {Path: "/path/to/main", HEAD: "abc123def456", Branch: "main"}, + }, + }, + { + name: "multiple worktrees", + input: `worktree /path/to/main +HEAD abc123def456 +branch refs/heads/main + +worktree /path/to/feature +HEAD def789ghi012 +branch refs/heads/feature-branch + +`, + expected: []Worktree{ + {Path: "/path/to/main", HEAD: "abc123def456", Branch: "main"}, + {Path: "/path/to/feature", HEAD: "def789ghi012", Branch: "feature-branch"}, + }, + }, + { + name: "detached HEAD worktree", + input: `worktree /path/to/detached +HEAD abc123def456 +detached + +`, + expected: []Worktree{ + {Path: "/path/to/detached", HEAD: "abc123def456", Branch: ""}, + }, + }, + { + name: "no trailing newline", + input: `worktree /path/to/main +HEAD abc123def456 +branch refs/heads/main`, + expected: []Worktree{ + {Path: "/path/to/main", HEAD: "abc123def456", Branch: "main"}, + }, + }, + { + name: "path with spaces", + input: `worktree /path/with spaces/to/main +HEAD abc123def456 +branch refs/heads/main +`, + expected: []Worktree{ + {Path: "/path/with spaces/to/main", HEAD: "abc123def456", Branch: "main"}, + }, + }, + { + name: "branch with slashes", + input: `worktree /path/to/feature +HEAD abc123def456 +branch refs/heads/feature/sub-feature/deep +`, + expected: []Worktree{ + {Path: "/path/to/feature", HEAD: "abc123def456", Branch: "feature/sub-feature/deep"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := parseWorktreeList(tt.input) + require.NoError(t, err) + + require.Len(t, result, len(tt.expected)) + + for i, expected := range tt.expected { + require.Equal(t, expected.Path, result[i].Path, "worktree[%d] path mismatch", i) + require.Equal(t, expected.HEAD, result[i].HEAD, "worktree[%d] HEAD mismatch", i) + require.Equal(t, expected.Branch, result[i].Branch, "worktree[%d] branch mismatch", i) + } + }) + } +} diff --git a/internal/zellij/zellij.go b/internal/zellij/zellij.go index aec26fa5..7f6a0cdb 100644 --- a/internal/zellij/zellij.go +++ b/internal/zellij/zellij.go @@ -2,6 +2,8 @@ // It abstracts session, tab, and pane management operations into a type-safe API. package zellij +//go:generate moq -stub -out zellij_mock.go . SessionManager Session + import ( "bytes" "context" diff --git a/internal/zellij/zellij_mock.go b/internal/zellij/zellij_mock.go new file mode 100644 index 00000000..c486e6a8 --- /dev/null +++ b/internal/zellij/zellij_mock.go @@ -0,0 +1,1456 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package zellij + +import ( + "context" + "sync" +) + +// Ensure, that SessionManagerMock does implement SessionManager. +// If this is not the case, regenerate this file with moq. +var _ SessionManager = &SessionManagerMock{} + +// SessionManagerMock is a mock implementation of SessionManager. +// +// func TestSomethingThatUsesSessionManager(t *testing.T) { +// +// // make and configure a mocked SessionManager +// mockedSessionManager := &SessionManagerMock{ +// CreateSessionFunc: func(ctx context.Context, name string) error { +// panic("mock out the CreateSession method") +// }, +// CreateSessionWithLayoutFunc: func(ctx context.Context, name string, projectRoot string) error { +// panic("mock out the CreateSessionWithLayout method") +// }, +// DeleteSessionFunc: func(ctx context.Context, name string) error { +// panic("mock out the DeleteSession method") +// }, +// EnsureSessionFunc: func(ctx context.Context, name string) (bool, error) { +// panic("mock out the EnsureSession method") +// }, +// EnsureSessionWithLayoutFunc: func(ctx context.Context, name string, projectRoot string) (bool, error) { +// panic("mock out the EnsureSessionWithLayout method") +// }, +// IsSessionActiveFunc: func(ctx context.Context, name string) (bool, error) { +// panic("mock out the IsSessionActive method") +// }, +// ListSessionsFunc: func(ctx context.Context) ([]string, error) { +// panic("mock out the ListSessions method") +// }, +// SessionFunc: func(name string) Session { +// panic("mock out the Session method") +// }, +// SessionExistsFunc: func(ctx context.Context, name string) (bool, error) { +// panic("mock out the SessionExists method") +// }, +// } +// +// // use mockedSessionManager in code that requires SessionManager +// // and then make assertions. +// +// } +type SessionManagerMock struct { + // CreateSessionFunc mocks the CreateSession method. + CreateSessionFunc func(ctx context.Context, name string) error + + // CreateSessionWithLayoutFunc mocks the CreateSessionWithLayout method. + CreateSessionWithLayoutFunc func(ctx context.Context, name string, projectRoot string) error + + // DeleteSessionFunc mocks the DeleteSession method. + DeleteSessionFunc func(ctx context.Context, name string) error + + // EnsureSessionFunc mocks the EnsureSession method. + EnsureSessionFunc func(ctx context.Context, name string) (bool, error) + + // EnsureSessionWithLayoutFunc mocks the EnsureSessionWithLayout method. + EnsureSessionWithLayoutFunc func(ctx context.Context, name string, projectRoot string) (bool, error) + + // IsSessionActiveFunc mocks the IsSessionActive method. + IsSessionActiveFunc func(ctx context.Context, name string) (bool, error) + + // ListSessionsFunc mocks the ListSessions method. + ListSessionsFunc func(ctx context.Context) ([]string, error) + + // SessionFunc mocks the Session method. + SessionFunc func(name string) Session + + // SessionExistsFunc mocks the SessionExists method. + SessionExistsFunc func(ctx context.Context, name string) (bool, error) + + // calls tracks calls to the methods. + calls struct { + // CreateSession holds details about calls to the CreateSession method. + CreateSession []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Name is the name argument value. + Name string + } + // CreateSessionWithLayout holds details about calls to the CreateSessionWithLayout method. + CreateSessionWithLayout []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Name is the name argument value. + Name string + // ProjectRoot is the projectRoot argument value. + ProjectRoot string + } + // DeleteSession holds details about calls to the DeleteSession method. + DeleteSession []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Name is the name argument value. + Name string + } + // EnsureSession holds details about calls to the EnsureSession method. + EnsureSession []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Name is the name argument value. + Name string + } + // EnsureSessionWithLayout holds details about calls to the EnsureSessionWithLayout method. + EnsureSessionWithLayout []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Name is the name argument value. + Name string + // ProjectRoot is the projectRoot argument value. + ProjectRoot string + } + // IsSessionActive holds details about calls to the IsSessionActive method. + IsSessionActive []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Name is the name argument value. + Name string + } + // ListSessions holds details about calls to the ListSessions method. + ListSessions []struct { + // Ctx is the ctx argument value. + Ctx context.Context + } + // Session holds details about calls to the Session method. + Session []struct { + // Name is the name argument value. + Name string + } + // SessionExists holds details about calls to the SessionExists method. + SessionExists []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Name is the name argument value. + Name string + } + } + lockCreateSession sync.RWMutex + lockCreateSessionWithLayout sync.RWMutex + lockDeleteSession sync.RWMutex + lockEnsureSession sync.RWMutex + lockEnsureSessionWithLayout sync.RWMutex + lockIsSessionActive sync.RWMutex + lockListSessions sync.RWMutex + lockSession sync.RWMutex + lockSessionExists sync.RWMutex +} + +// CreateSession calls CreateSessionFunc. +func (mock *SessionManagerMock) CreateSession(ctx context.Context, name string) error { + callInfo := struct { + Ctx context.Context + Name string + }{ + Ctx: ctx, + Name: name, + } + mock.lockCreateSession.Lock() + mock.calls.CreateSession = append(mock.calls.CreateSession, callInfo) + mock.lockCreateSession.Unlock() + if mock.CreateSessionFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.CreateSessionFunc(ctx, name) +} + +// CreateSessionCalls gets all the calls that were made to CreateSession. +// Check the length with: +// +// len(mockedSessionManager.CreateSessionCalls()) +func (mock *SessionManagerMock) CreateSessionCalls() []struct { + Ctx context.Context + Name string +} { + var calls []struct { + Ctx context.Context + Name string + } + mock.lockCreateSession.RLock() + calls = mock.calls.CreateSession + mock.lockCreateSession.RUnlock() + return calls +} + +// CreateSessionWithLayout calls CreateSessionWithLayoutFunc. +func (mock *SessionManagerMock) CreateSessionWithLayout(ctx context.Context, name string, projectRoot string) error { + callInfo := struct { + Ctx context.Context + Name string + ProjectRoot string + }{ + Ctx: ctx, + Name: name, + ProjectRoot: projectRoot, + } + mock.lockCreateSessionWithLayout.Lock() + mock.calls.CreateSessionWithLayout = append(mock.calls.CreateSessionWithLayout, callInfo) + mock.lockCreateSessionWithLayout.Unlock() + if mock.CreateSessionWithLayoutFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.CreateSessionWithLayoutFunc(ctx, name, projectRoot) +} + +// CreateSessionWithLayoutCalls gets all the calls that were made to CreateSessionWithLayout. +// Check the length with: +// +// len(mockedSessionManager.CreateSessionWithLayoutCalls()) +func (mock *SessionManagerMock) CreateSessionWithLayoutCalls() []struct { + Ctx context.Context + Name string + ProjectRoot string +} { + var calls []struct { + Ctx context.Context + Name string + ProjectRoot string + } + mock.lockCreateSessionWithLayout.RLock() + calls = mock.calls.CreateSessionWithLayout + mock.lockCreateSessionWithLayout.RUnlock() + return calls +} + +// DeleteSession calls DeleteSessionFunc. +func (mock *SessionManagerMock) DeleteSession(ctx context.Context, name string) error { + callInfo := struct { + Ctx context.Context + Name string + }{ + Ctx: ctx, + Name: name, + } + mock.lockDeleteSession.Lock() + mock.calls.DeleteSession = append(mock.calls.DeleteSession, callInfo) + mock.lockDeleteSession.Unlock() + if mock.DeleteSessionFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.DeleteSessionFunc(ctx, name) +} + +// DeleteSessionCalls gets all the calls that were made to DeleteSession. +// Check the length with: +// +// len(mockedSessionManager.DeleteSessionCalls()) +func (mock *SessionManagerMock) DeleteSessionCalls() []struct { + Ctx context.Context + Name string +} { + var calls []struct { + Ctx context.Context + Name string + } + mock.lockDeleteSession.RLock() + calls = mock.calls.DeleteSession + mock.lockDeleteSession.RUnlock() + return calls +} + +// EnsureSession calls EnsureSessionFunc. +func (mock *SessionManagerMock) EnsureSession(ctx context.Context, name string) (bool, error) { + callInfo := struct { + Ctx context.Context + Name string + }{ + Ctx: ctx, + Name: name, + } + mock.lockEnsureSession.Lock() + mock.calls.EnsureSession = append(mock.calls.EnsureSession, callInfo) + mock.lockEnsureSession.Unlock() + if mock.EnsureSessionFunc == nil { + var ( + bOut bool + errOut error + ) + return bOut, errOut + } + return mock.EnsureSessionFunc(ctx, name) +} + +// EnsureSessionCalls gets all the calls that were made to EnsureSession. +// Check the length with: +// +// len(mockedSessionManager.EnsureSessionCalls()) +func (mock *SessionManagerMock) EnsureSessionCalls() []struct { + Ctx context.Context + Name string +} { + var calls []struct { + Ctx context.Context + Name string + } + mock.lockEnsureSession.RLock() + calls = mock.calls.EnsureSession + mock.lockEnsureSession.RUnlock() + return calls +} + +// EnsureSessionWithLayout calls EnsureSessionWithLayoutFunc. +func (mock *SessionManagerMock) EnsureSessionWithLayout(ctx context.Context, name string, projectRoot string) (bool, error) { + callInfo := struct { + Ctx context.Context + Name string + ProjectRoot string + }{ + Ctx: ctx, + Name: name, + ProjectRoot: projectRoot, + } + mock.lockEnsureSessionWithLayout.Lock() + mock.calls.EnsureSessionWithLayout = append(mock.calls.EnsureSessionWithLayout, callInfo) + mock.lockEnsureSessionWithLayout.Unlock() + if mock.EnsureSessionWithLayoutFunc == nil { + var ( + bOut bool + errOut error + ) + return bOut, errOut + } + return mock.EnsureSessionWithLayoutFunc(ctx, name, projectRoot) +} + +// EnsureSessionWithLayoutCalls gets all the calls that were made to EnsureSessionWithLayout. +// Check the length with: +// +// len(mockedSessionManager.EnsureSessionWithLayoutCalls()) +func (mock *SessionManagerMock) EnsureSessionWithLayoutCalls() []struct { + Ctx context.Context + Name string + ProjectRoot string +} { + var calls []struct { + Ctx context.Context + Name string + ProjectRoot string + } + mock.lockEnsureSessionWithLayout.RLock() + calls = mock.calls.EnsureSessionWithLayout + mock.lockEnsureSessionWithLayout.RUnlock() + return calls +} + +// IsSessionActive calls IsSessionActiveFunc. +func (mock *SessionManagerMock) IsSessionActive(ctx context.Context, name string) (bool, error) { + callInfo := struct { + Ctx context.Context + Name string + }{ + Ctx: ctx, + Name: name, + } + mock.lockIsSessionActive.Lock() + mock.calls.IsSessionActive = append(mock.calls.IsSessionActive, callInfo) + mock.lockIsSessionActive.Unlock() + if mock.IsSessionActiveFunc == nil { + var ( + bOut bool + errOut error + ) + return bOut, errOut + } + return mock.IsSessionActiveFunc(ctx, name) +} + +// IsSessionActiveCalls gets all the calls that were made to IsSessionActive. +// Check the length with: +// +// len(mockedSessionManager.IsSessionActiveCalls()) +func (mock *SessionManagerMock) IsSessionActiveCalls() []struct { + Ctx context.Context + Name string +} { + var calls []struct { + Ctx context.Context + Name string + } + mock.lockIsSessionActive.RLock() + calls = mock.calls.IsSessionActive + mock.lockIsSessionActive.RUnlock() + return calls +} + +// ListSessions calls ListSessionsFunc. +func (mock *SessionManagerMock) ListSessions(ctx context.Context) ([]string, error) { + callInfo := struct { + Ctx context.Context + }{ + Ctx: ctx, + } + mock.lockListSessions.Lock() + mock.calls.ListSessions = append(mock.calls.ListSessions, callInfo) + mock.lockListSessions.Unlock() + if mock.ListSessionsFunc == nil { + var ( + stringsOut []string + errOut error + ) + return stringsOut, errOut + } + return mock.ListSessionsFunc(ctx) +} + +// ListSessionsCalls gets all the calls that were made to ListSessions. +// Check the length with: +// +// len(mockedSessionManager.ListSessionsCalls()) +func (mock *SessionManagerMock) ListSessionsCalls() []struct { + Ctx context.Context +} { + var calls []struct { + Ctx context.Context + } + mock.lockListSessions.RLock() + calls = mock.calls.ListSessions + mock.lockListSessions.RUnlock() + return calls +} + +// Session calls SessionFunc. +func (mock *SessionManagerMock) Session(name string) Session { + callInfo := struct { + Name string + }{ + Name: name, + } + mock.lockSession.Lock() + mock.calls.Session = append(mock.calls.Session, callInfo) + mock.lockSession.Unlock() + if mock.SessionFunc == nil { + var ( + sessionOut Session + ) + return sessionOut + } + return mock.SessionFunc(name) +} + +// SessionCalls gets all the calls that were made to Session. +// Check the length with: +// +// len(mockedSessionManager.SessionCalls()) +func (mock *SessionManagerMock) SessionCalls() []struct { + Name string +} { + var calls []struct { + Name string + } + mock.lockSession.RLock() + calls = mock.calls.Session + mock.lockSession.RUnlock() + return calls +} + +// SessionExists calls SessionExistsFunc. +func (mock *SessionManagerMock) SessionExists(ctx context.Context, name string) (bool, error) { + callInfo := struct { + Ctx context.Context + Name string + }{ + Ctx: ctx, + Name: name, + } + mock.lockSessionExists.Lock() + mock.calls.SessionExists = append(mock.calls.SessionExists, callInfo) + mock.lockSessionExists.Unlock() + if mock.SessionExistsFunc == nil { + var ( + bOut bool + errOut error + ) + return bOut, errOut + } + return mock.SessionExistsFunc(ctx, name) +} + +// SessionExistsCalls gets all the calls that were made to SessionExists. +// Check the length with: +// +// len(mockedSessionManager.SessionExistsCalls()) +func (mock *SessionManagerMock) SessionExistsCalls() []struct { + Ctx context.Context + Name string +} { + var calls []struct { + Ctx context.Context + Name string + } + mock.lockSessionExists.RLock() + calls = mock.calls.SessionExists + mock.lockSessionExists.RUnlock() + return calls +} + +// Ensure, that SessionMock does implement Session. +// If this is not the case, regenerate this file with moq. +var _ Session = &SessionMock{} + +// SessionMock is a mock implementation of Session. +// +// func TestSomethingThatUsesSession(t *testing.T) { +// +// // make and configure a mocked Session +// mockedSession := &SessionMock{ +// ClearAndExecuteFunc: func(ctx context.Context, cmd string) error { +// panic("mock out the ClearAndExecute method") +// }, +// CloseTabFunc: func(ctx context.Context) error { +// panic("mock out the CloseTab method") +// }, +// CreateTabFunc: func(ctx context.Context, name string, cwd string) error { +// panic("mock out the CreateTab method") +// }, +// CreateTabWithCommandFunc: func(ctx context.Context, name string, cwd string, command string, args []string, paneName string) error { +// panic("mock out the CreateTabWithCommand method") +// }, +// ExecuteCommandFunc: func(ctx context.Context, cmd string) error { +// panic("mock out the ExecuteCommand method") +// }, +// QueryTabNamesFunc: func(ctx context.Context) ([]string, error) { +// panic("mock out the QueryTabNames method") +// }, +// RunFunc: func(ctx context.Context, name string, cwd string, command ...string) error { +// panic("mock out the Run method") +// }, +// RunFloatingFunc: func(ctx context.Context, name string, cwd string, command ...string) error { +// panic("mock out the RunFloating method") +// }, +// SendCtrlCFunc: func(ctx context.Context) error { +// panic("mock out the SendCtrlC method") +// }, +// SendEnterFunc: func(ctx context.Context) error { +// panic("mock out the SendEnter method") +// }, +// SwitchToTabFunc: func(ctx context.Context, name string) error { +// panic("mock out the SwitchToTab method") +// }, +// TabExistsFunc: func(ctx context.Context, name string) (bool, error) { +// panic("mock out the TabExists method") +// }, +// TerminateAndCloseTabFunc: func(ctx context.Context, tabName string) error { +// panic("mock out the TerminateAndCloseTab method") +// }, +// TerminateProcessFunc: func(ctx context.Context) error { +// panic("mock out the TerminateProcess method") +// }, +// ToggleFloatingPanesFunc: func(ctx context.Context) error { +// panic("mock out the ToggleFloatingPanes method") +// }, +// WriteASCIIFunc: func(ctx context.Context, code int) error { +// panic("mock out the WriteASCII method") +// }, +// WriteCharsFunc: func(ctx context.Context, text string) error { +// panic("mock out the WriteChars method") +// }, +// } +// +// // use mockedSession in code that requires Session +// // and then make assertions. +// +// } +type SessionMock struct { + // ClearAndExecuteFunc mocks the ClearAndExecute method. + ClearAndExecuteFunc func(ctx context.Context, cmd string) error + + // CloseTabFunc mocks the CloseTab method. + CloseTabFunc func(ctx context.Context) error + + // CreateTabFunc mocks the CreateTab method. + CreateTabFunc func(ctx context.Context, name string, cwd string) error + + // CreateTabWithCommandFunc mocks the CreateTabWithCommand method. + CreateTabWithCommandFunc func(ctx context.Context, name string, cwd string, command string, args []string, paneName string) error + + // ExecuteCommandFunc mocks the ExecuteCommand method. + ExecuteCommandFunc func(ctx context.Context, cmd string) error + + // QueryTabNamesFunc mocks the QueryTabNames method. + QueryTabNamesFunc func(ctx context.Context) ([]string, error) + + // RunFunc mocks the Run method. + RunFunc func(ctx context.Context, name string, cwd string, command ...string) error + + // RunFloatingFunc mocks the RunFloating method. + RunFloatingFunc func(ctx context.Context, name string, cwd string, command ...string) error + + // SendCtrlCFunc mocks the SendCtrlC method. + SendCtrlCFunc func(ctx context.Context) error + + // SendEnterFunc mocks the SendEnter method. + SendEnterFunc func(ctx context.Context) error + + // SwitchToTabFunc mocks the SwitchToTab method. + SwitchToTabFunc func(ctx context.Context, name string) error + + // TabExistsFunc mocks the TabExists method. + TabExistsFunc func(ctx context.Context, name string) (bool, error) + + // TerminateAndCloseTabFunc mocks the TerminateAndCloseTab method. + TerminateAndCloseTabFunc func(ctx context.Context, tabName string) error + + // TerminateProcessFunc mocks the TerminateProcess method. + TerminateProcessFunc func(ctx context.Context) error + + // ToggleFloatingPanesFunc mocks the ToggleFloatingPanes method. + ToggleFloatingPanesFunc func(ctx context.Context) error + + // WriteASCIIFunc mocks the WriteASCII method. + WriteASCIIFunc func(ctx context.Context, code int) error + + // WriteCharsFunc mocks the WriteChars method. + WriteCharsFunc func(ctx context.Context, text string) error + + // calls tracks calls to the methods. + calls struct { + // ClearAndExecute holds details about calls to the ClearAndExecute method. + ClearAndExecute []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Cmd is the cmd argument value. + Cmd string + } + // CloseTab holds details about calls to the CloseTab method. + CloseTab []struct { + // Ctx is the ctx argument value. + Ctx context.Context + } + // CreateTab holds details about calls to the CreateTab method. + CreateTab []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Name is the name argument value. + Name string + // Cwd is the cwd argument value. + Cwd string + } + // CreateTabWithCommand holds details about calls to the CreateTabWithCommand method. + CreateTabWithCommand []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Name is the name argument value. + Name string + // Cwd is the cwd argument value. + Cwd string + // Command is the command argument value. + Command string + // Args is the args argument value. + Args []string + // PaneName is the paneName argument value. + PaneName string + } + // ExecuteCommand holds details about calls to the ExecuteCommand method. + ExecuteCommand []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Cmd is the cmd argument value. + Cmd string + } + // QueryTabNames holds details about calls to the QueryTabNames method. + QueryTabNames []struct { + // Ctx is the ctx argument value. + Ctx context.Context + } + // Run holds details about calls to the Run method. + Run []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Name is the name argument value. + Name string + // Cwd is the cwd argument value. + Cwd string + // Command is the command argument value. + Command []string + } + // RunFloating holds details about calls to the RunFloating method. + RunFloating []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Name is the name argument value. + Name string + // Cwd is the cwd argument value. + Cwd string + // Command is the command argument value. + Command []string + } + // SendCtrlC holds details about calls to the SendCtrlC method. + SendCtrlC []struct { + // Ctx is the ctx argument value. + Ctx context.Context + } + // SendEnter holds details about calls to the SendEnter method. + SendEnter []struct { + // Ctx is the ctx argument value. + Ctx context.Context + } + // SwitchToTab holds details about calls to the SwitchToTab method. + SwitchToTab []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Name is the name argument value. + Name string + } + // TabExists holds details about calls to the TabExists method. + TabExists []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Name is the name argument value. + Name string + } + // TerminateAndCloseTab holds details about calls to the TerminateAndCloseTab method. + TerminateAndCloseTab []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // TabName is the tabName argument value. + TabName string + } + // TerminateProcess holds details about calls to the TerminateProcess method. + TerminateProcess []struct { + // Ctx is the ctx argument value. + Ctx context.Context + } + // ToggleFloatingPanes holds details about calls to the ToggleFloatingPanes method. + ToggleFloatingPanes []struct { + // Ctx is the ctx argument value. + Ctx context.Context + } + // WriteASCII holds details about calls to the WriteASCII method. + WriteASCII []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Code is the code argument value. + Code int + } + // WriteChars holds details about calls to the WriteChars method. + WriteChars []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Text is the text argument value. + Text string + } + } + lockClearAndExecute sync.RWMutex + lockCloseTab sync.RWMutex + lockCreateTab sync.RWMutex + lockCreateTabWithCommand sync.RWMutex + lockExecuteCommand sync.RWMutex + lockQueryTabNames sync.RWMutex + lockRun sync.RWMutex + lockRunFloating sync.RWMutex + lockSendCtrlC sync.RWMutex + lockSendEnter sync.RWMutex + lockSwitchToTab sync.RWMutex + lockTabExists sync.RWMutex + lockTerminateAndCloseTab sync.RWMutex + lockTerminateProcess sync.RWMutex + lockToggleFloatingPanes sync.RWMutex + lockWriteASCII sync.RWMutex + lockWriteChars sync.RWMutex +} + +// ClearAndExecute calls ClearAndExecuteFunc. +func (mock *SessionMock) ClearAndExecute(ctx context.Context, cmd string) error { + callInfo := struct { + Ctx context.Context + Cmd string + }{ + Ctx: ctx, + Cmd: cmd, + } + mock.lockClearAndExecute.Lock() + mock.calls.ClearAndExecute = append(mock.calls.ClearAndExecute, callInfo) + mock.lockClearAndExecute.Unlock() + if mock.ClearAndExecuteFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.ClearAndExecuteFunc(ctx, cmd) +} + +// ClearAndExecuteCalls gets all the calls that were made to ClearAndExecute. +// Check the length with: +// +// len(mockedSession.ClearAndExecuteCalls()) +func (mock *SessionMock) ClearAndExecuteCalls() []struct { + Ctx context.Context + Cmd string +} { + var calls []struct { + Ctx context.Context + Cmd string + } + mock.lockClearAndExecute.RLock() + calls = mock.calls.ClearAndExecute + mock.lockClearAndExecute.RUnlock() + return calls +} + +// CloseTab calls CloseTabFunc. +func (mock *SessionMock) CloseTab(ctx context.Context) error { + callInfo := struct { + Ctx context.Context + }{ + Ctx: ctx, + } + mock.lockCloseTab.Lock() + mock.calls.CloseTab = append(mock.calls.CloseTab, callInfo) + mock.lockCloseTab.Unlock() + if mock.CloseTabFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.CloseTabFunc(ctx) +} + +// CloseTabCalls gets all the calls that were made to CloseTab. +// Check the length with: +// +// len(mockedSession.CloseTabCalls()) +func (mock *SessionMock) CloseTabCalls() []struct { + Ctx context.Context +} { + var calls []struct { + Ctx context.Context + } + mock.lockCloseTab.RLock() + calls = mock.calls.CloseTab + mock.lockCloseTab.RUnlock() + return calls +} + +// CreateTab calls CreateTabFunc. +func (mock *SessionMock) CreateTab(ctx context.Context, name string, cwd string) error { + callInfo := struct { + Ctx context.Context + Name string + Cwd string + }{ + Ctx: ctx, + Name: name, + Cwd: cwd, + } + mock.lockCreateTab.Lock() + mock.calls.CreateTab = append(mock.calls.CreateTab, callInfo) + mock.lockCreateTab.Unlock() + if mock.CreateTabFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.CreateTabFunc(ctx, name, cwd) +} + +// CreateTabCalls gets all the calls that were made to CreateTab. +// Check the length with: +// +// len(mockedSession.CreateTabCalls()) +func (mock *SessionMock) CreateTabCalls() []struct { + Ctx context.Context + Name string + Cwd string +} { + var calls []struct { + Ctx context.Context + Name string + Cwd string + } + mock.lockCreateTab.RLock() + calls = mock.calls.CreateTab + mock.lockCreateTab.RUnlock() + return calls +} + +// CreateTabWithCommand calls CreateTabWithCommandFunc. +func (mock *SessionMock) CreateTabWithCommand(ctx context.Context, name string, cwd string, command string, args []string, paneName string) error { + callInfo := struct { + Ctx context.Context + Name string + Cwd string + Command string + Args []string + PaneName string + }{ + Ctx: ctx, + Name: name, + Cwd: cwd, + Command: command, + Args: args, + PaneName: paneName, + } + mock.lockCreateTabWithCommand.Lock() + mock.calls.CreateTabWithCommand = append(mock.calls.CreateTabWithCommand, callInfo) + mock.lockCreateTabWithCommand.Unlock() + if mock.CreateTabWithCommandFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.CreateTabWithCommandFunc(ctx, name, cwd, command, args, paneName) +} + +// CreateTabWithCommandCalls gets all the calls that were made to CreateTabWithCommand. +// Check the length with: +// +// len(mockedSession.CreateTabWithCommandCalls()) +func (mock *SessionMock) CreateTabWithCommandCalls() []struct { + Ctx context.Context + Name string + Cwd string + Command string + Args []string + PaneName string +} { + var calls []struct { + Ctx context.Context + Name string + Cwd string + Command string + Args []string + PaneName string + } + mock.lockCreateTabWithCommand.RLock() + calls = mock.calls.CreateTabWithCommand + mock.lockCreateTabWithCommand.RUnlock() + return calls +} + +// ExecuteCommand calls ExecuteCommandFunc. +func (mock *SessionMock) ExecuteCommand(ctx context.Context, cmd string) error { + callInfo := struct { + Ctx context.Context + Cmd string + }{ + Ctx: ctx, + Cmd: cmd, + } + mock.lockExecuteCommand.Lock() + mock.calls.ExecuteCommand = append(mock.calls.ExecuteCommand, callInfo) + mock.lockExecuteCommand.Unlock() + if mock.ExecuteCommandFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.ExecuteCommandFunc(ctx, cmd) +} + +// ExecuteCommandCalls gets all the calls that were made to ExecuteCommand. +// Check the length with: +// +// len(mockedSession.ExecuteCommandCalls()) +func (mock *SessionMock) ExecuteCommandCalls() []struct { + Ctx context.Context + Cmd string +} { + var calls []struct { + Ctx context.Context + Cmd string + } + mock.lockExecuteCommand.RLock() + calls = mock.calls.ExecuteCommand + mock.lockExecuteCommand.RUnlock() + return calls +} + +// QueryTabNames calls QueryTabNamesFunc. +func (mock *SessionMock) QueryTabNames(ctx context.Context) ([]string, error) { + callInfo := struct { + Ctx context.Context + }{ + Ctx: ctx, + } + mock.lockQueryTabNames.Lock() + mock.calls.QueryTabNames = append(mock.calls.QueryTabNames, callInfo) + mock.lockQueryTabNames.Unlock() + if mock.QueryTabNamesFunc == nil { + var ( + stringsOut []string + errOut error + ) + return stringsOut, errOut + } + return mock.QueryTabNamesFunc(ctx) +} + +// QueryTabNamesCalls gets all the calls that were made to QueryTabNames. +// Check the length with: +// +// len(mockedSession.QueryTabNamesCalls()) +func (mock *SessionMock) QueryTabNamesCalls() []struct { + Ctx context.Context +} { + var calls []struct { + Ctx context.Context + } + mock.lockQueryTabNames.RLock() + calls = mock.calls.QueryTabNames + mock.lockQueryTabNames.RUnlock() + return calls +} + +// Run calls RunFunc. +func (mock *SessionMock) Run(ctx context.Context, name string, cwd string, command ...string) error { + callInfo := struct { + Ctx context.Context + Name string + Cwd string + Command []string + }{ + Ctx: ctx, + Name: name, + Cwd: cwd, + Command: command, + } + mock.lockRun.Lock() + mock.calls.Run = append(mock.calls.Run, callInfo) + mock.lockRun.Unlock() + if mock.RunFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.RunFunc(ctx, name, cwd, command...) +} + +// RunCalls gets all the calls that were made to Run. +// Check the length with: +// +// len(mockedSession.RunCalls()) +func (mock *SessionMock) RunCalls() []struct { + Ctx context.Context + Name string + Cwd string + Command []string +} { + var calls []struct { + Ctx context.Context + Name string + Cwd string + Command []string + } + mock.lockRun.RLock() + calls = mock.calls.Run + mock.lockRun.RUnlock() + return calls +} + +// RunFloating calls RunFloatingFunc. +func (mock *SessionMock) RunFloating(ctx context.Context, name string, cwd string, command ...string) error { + callInfo := struct { + Ctx context.Context + Name string + Cwd string + Command []string + }{ + Ctx: ctx, + Name: name, + Cwd: cwd, + Command: command, + } + mock.lockRunFloating.Lock() + mock.calls.RunFloating = append(mock.calls.RunFloating, callInfo) + mock.lockRunFloating.Unlock() + if mock.RunFloatingFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.RunFloatingFunc(ctx, name, cwd, command...) +} + +// RunFloatingCalls gets all the calls that were made to RunFloating. +// Check the length with: +// +// len(mockedSession.RunFloatingCalls()) +func (mock *SessionMock) RunFloatingCalls() []struct { + Ctx context.Context + Name string + Cwd string + Command []string +} { + var calls []struct { + Ctx context.Context + Name string + Cwd string + Command []string + } + mock.lockRunFloating.RLock() + calls = mock.calls.RunFloating + mock.lockRunFloating.RUnlock() + return calls +} + +// SendCtrlC calls SendCtrlCFunc. +func (mock *SessionMock) SendCtrlC(ctx context.Context) error { + callInfo := struct { + Ctx context.Context + }{ + Ctx: ctx, + } + mock.lockSendCtrlC.Lock() + mock.calls.SendCtrlC = append(mock.calls.SendCtrlC, callInfo) + mock.lockSendCtrlC.Unlock() + if mock.SendCtrlCFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.SendCtrlCFunc(ctx) +} + +// SendCtrlCCalls gets all the calls that were made to SendCtrlC. +// Check the length with: +// +// len(mockedSession.SendCtrlCCalls()) +func (mock *SessionMock) SendCtrlCCalls() []struct { + Ctx context.Context +} { + var calls []struct { + Ctx context.Context + } + mock.lockSendCtrlC.RLock() + calls = mock.calls.SendCtrlC + mock.lockSendCtrlC.RUnlock() + return calls +} + +// SendEnter calls SendEnterFunc. +func (mock *SessionMock) SendEnter(ctx context.Context) error { + callInfo := struct { + Ctx context.Context + }{ + Ctx: ctx, + } + mock.lockSendEnter.Lock() + mock.calls.SendEnter = append(mock.calls.SendEnter, callInfo) + mock.lockSendEnter.Unlock() + if mock.SendEnterFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.SendEnterFunc(ctx) +} + +// SendEnterCalls gets all the calls that were made to SendEnter. +// Check the length with: +// +// len(mockedSession.SendEnterCalls()) +func (mock *SessionMock) SendEnterCalls() []struct { + Ctx context.Context +} { + var calls []struct { + Ctx context.Context + } + mock.lockSendEnter.RLock() + calls = mock.calls.SendEnter + mock.lockSendEnter.RUnlock() + return calls +} + +// SwitchToTab calls SwitchToTabFunc. +func (mock *SessionMock) SwitchToTab(ctx context.Context, name string) error { + callInfo := struct { + Ctx context.Context + Name string + }{ + Ctx: ctx, + Name: name, + } + mock.lockSwitchToTab.Lock() + mock.calls.SwitchToTab = append(mock.calls.SwitchToTab, callInfo) + mock.lockSwitchToTab.Unlock() + if mock.SwitchToTabFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.SwitchToTabFunc(ctx, name) +} + +// SwitchToTabCalls gets all the calls that were made to SwitchToTab. +// Check the length with: +// +// len(mockedSession.SwitchToTabCalls()) +func (mock *SessionMock) SwitchToTabCalls() []struct { + Ctx context.Context + Name string +} { + var calls []struct { + Ctx context.Context + Name string + } + mock.lockSwitchToTab.RLock() + calls = mock.calls.SwitchToTab + mock.lockSwitchToTab.RUnlock() + return calls +} + +// TabExists calls TabExistsFunc. +func (mock *SessionMock) TabExists(ctx context.Context, name string) (bool, error) { + callInfo := struct { + Ctx context.Context + Name string + }{ + Ctx: ctx, + Name: name, + } + mock.lockTabExists.Lock() + mock.calls.TabExists = append(mock.calls.TabExists, callInfo) + mock.lockTabExists.Unlock() + if mock.TabExistsFunc == nil { + var ( + bOut bool + errOut error + ) + return bOut, errOut + } + return mock.TabExistsFunc(ctx, name) +} + +// TabExistsCalls gets all the calls that were made to TabExists. +// Check the length with: +// +// len(mockedSession.TabExistsCalls()) +func (mock *SessionMock) TabExistsCalls() []struct { + Ctx context.Context + Name string +} { + var calls []struct { + Ctx context.Context + Name string + } + mock.lockTabExists.RLock() + calls = mock.calls.TabExists + mock.lockTabExists.RUnlock() + return calls +} + +// TerminateAndCloseTab calls TerminateAndCloseTabFunc. +func (mock *SessionMock) TerminateAndCloseTab(ctx context.Context, tabName string) error { + callInfo := struct { + Ctx context.Context + TabName string + }{ + Ctx: ctx, + TabName: tabName, + } + mock.lockTerminateAndCloseTab.Lock() + mock.calls.TerminateAndCloseTab = append(mock.calls.TerminateAndCloseTab, callInfo) + mock.lockTerminateAndCloseTab.Unlock() + if mock.TerminateAndCloseTabFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.TerminateAndCloseTabFunc(ctx, tabName) +} + +// TerminateAndCloseTabCalls gets all the calls that were made to TerminateAndCloseTab. +// Check the length with: +// +// len(mockedSession.TerminateAndCloseTabCalls()) +func (mock *SessionMock) TerminateAndCloseTabCalls() []struct { + Ctx context.Context + TabName string +} { + var calls []struct { + Ctx context.Context + TabName string + } + mock.lockTerminateAndCloseTab.RLock() + calls = mock.calls.TerminateAndCloseTab + mock.lockTerminateAndCloseTab.RUnlock() + return calls +} + +// TerminateProcess calls TerminateProcessFunc. +func (mock *SessionMock) TerminateProcess(ctx context.Context) error { + callInfo := struct { + Ctx context.Context + }{ + Ctx: ctx, + } + mock.lockTerminateProcess.Lock() + mock.calls.TerminateProcess = append(mock.calls.TerminateProcess, callInfo) + mock.lockTerminateProcess.Unlock() + if mock.TerminateProcessFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.TerminateProcessFunc(ctx) +} + +// TerminateProcessCalls gets all the calls that were made to TerminateProcess. +// Check the length with: +// +// len(mockedSession.TerminateProcessCalls()) +func (mock *SessionMock) TerminateProcessCalls() []struct { + Ctx context.Context +} { + var calls []struct { + Ctx context.Context + } + mock.lockTerminateProcess.RLock() + calls = mock.calls.TerminateProcess + mock.lockTerminateProcess.RUnlock() + return calls +} + +// ToggleFloatingPanes calls ToggleFloatingPanesFunc. +func (mock *SessionMock) ToggleFloatingPanes(ctx context.Context) error { + callInfo := struct { + Ctx context.Context + }{ + Ctx: ctx, + } + mock.lockToggleFloatingPanes.Lock() + mock.calls.ToggleFloatingPanes = append(mock.calls.ToggleFloatingPanes, callInfo) + mock.lockToggleFloatingPanes.Unlock() + if mock.ToggleFloatingPanesFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.ToggleFloatingPanesFunc(ctx) +} + +// ToggleFloatingPanesCalls gets all the calls that were made to ToggleFloatingPanes. +// Check the length with: +// +// len(mockedSession.ToggleFloatingPanesCalls()) +func (mock *SessionMock) ToggleFloatingPanesCalls() []struct { + Ctx context.Context +} { + var calls []struct { + Ctx context.Context + } + mock.lockToggleFloatingPanes.RLock() + calls = mock.calls.ToggleFloatingPanes + mock.lockToggleFloatingPanes.RUnlock() + return calls +} + +// WriteASCII calls WriteASCIIFunc. +func (mock *SessionMock) WriteASCII(ctx context.Context, code int) error { + callInfo := struct { + Ctx context.Context + Code int + }{ + Ctx: ctx, + Code: code, + } + mock.lockWriteASCII.Lock() + mock.calls.WriteASCII = append(mock.calls.WriteASCII, callInfo) + mock.lockWriteASCII.Unlock() + if mock.WriteASCIIFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.WriteASCIIFunc(ctx, code) +} + +// WriteASCIICalls gets all the calls that were made to WriteASCII. +// Check the length with: +// +// len(mockedSession.WriteASCIICalls()) +func (mock *SessionMock) WriteASCIICalls() []struct { + Ctx context.Context + Code int +} { + var calls []struct { + Ctx context.Context + Code int + } + mock.lockWriteASCII.RLock() + calls = mock.calls.WriteASCII + mock.lockWriteASCII.RUnlock() + return calls +} + +// WriteChars calls WriteCharsFunc. +func (mock *SessionMock) WriteChars(ctx context.Context, text string) error { + callInfo := struct { + Ctx context.Context + Text string + }{ + Ctx: ctx, + Text: text, + } + mock.lockWriteChars.Lock() + mock.calls.WriteChars = append(mock.calls.WriteChars, callInfo) + mock.lockWriteChars.Unlock() + if mock.WriteCharsFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.WriteCharsFunc(ctx, text) +} + +// WriteCharsCalls gets all the calls that were made to WriteChars. +// Check the length with: +// +// len(mockedSession.WriteCharsCalls()) +func (mock *SessionMock) WriteCharsCalls() []struct { + Ctx context.Context + Text string +} { + var calls []struct { + Ctx context.Context + Text string + } + mock.lockWriteChars.RLock() + calls = mock.calls.WriteChars + mock.lockWriteChars.RUnlock() + return calls +} diff --git a/internal/zellij/zellij_test.go b/internal/zellij/zellij_test.go index db023c5c..2929d14c 100644 --- a/internal/zellij/zellij_test.go +++ b/internal/zellij/zellij_test.go @@ -3,37 +3,25 @@ package zellij import ( "testing" "time" + + "github.com/stretchr/testify/require" ) func TestNew(t *testing.T) { client := New() - if client == nil { - t.Fatal("New() returned nil") - } + require.NotNil(t, client, "New() returned nil") // Check default values - if client.TabCreateDelay != 500*time.Millisecond { - t.Errorf("TabCreateDelay = %v, want %v", client.TabCreateDelay, 500*time.Millisecond) - } - if client.CtrlCDelay != 500*time.Millisecond { - t.Errorf("CtrlCDelay = %v, want %v", client.CtrlCDelay, 500*time.Millisecond) - } - if client.CommandDelay != 100*time.Millisecond { - t.Errorf("CommandDelay = %v, want %v", client.CommandDelay, 100*time.Millisecond) - } - if client.SessionStartWait != 1*time.Second { - t.Errorf("SessionStartWait = %v, want %v", client.SessionStartWait, 1*time.Second) - } + require.Equal(t, 500*time.Millisecond, client.TabCreateDelay) + require.Equal(t, 500*time.Millisecond, client.CtrlCDelay) + require.Equal(t, 100*time.Millisecond, client.CommandDelay) + require.Equal(t, 1*time.Second, client.SessionStartWait) } func TestASCIIConstants(t *testing.T) { // Verify ASCII constants are correct - if ASCIICtrlC != 3 { - t.Errorf("ASCIICtrlC = %d, want 3", ASCIICtrlC) - } - if ASCIIEnter != 13 { - t.Errorf("ASCIIEnter = %d, want 13", ASCIIEnter) - } + require.Equal(t, 3, ASCIICtrlC) + require.Equal(t, 13, ASCIIEnter) } func TestClientConfiguration(t *testing.T) { @@ -45,16 +33,8 @@ func TestClientConfiguration(t *testing.T) { client.CommandDelay = 50 * time.Millisecond client.SessionStartWait = 2 * time.Second - if client.TabCreateDelay != 1*time.Second { - t.Errorf("TabCreateDelay not updated correctly") - } - if client.CtrlCDelay != 250*time.Millisecond { - t.Errorf("CtrlCDelay not updated correctly") - } - if client.CommandDelay != 50*time.Millisecond { - t.Errorf("CommandDelay not updated correctly") - } - if client.SessionStartWait != 2*time.Second { - t.Errorf("SessionStartWait not updated correctly") - } + require.Equal(t, 1*time.Second, client.TabCreateDelay, "TabCreateDelay not updated correctly") + require.Equal(t, 250*time.Millisecond, client.CtrlCDelay, "CtrlCDelay not updated correctly") + require.Equal(t, 50*time.Millisecond, client.CommandDelay, "CommandDelay not updated correctly") + require.Equal(t, 2*time.Second, client.SessionStartWait, "SessionStartWait not updated correctly") } diff --git a/mise.toml b/mise.toml index 470f23e7..bcfb523c 100644 --- a/mise.toml +++ b/mise.toml @@ -6,6 +6,7 @@ go = "1.25" "aqua:golangci/golangci-lint" = "v2.8.0" "aqua:evilmartians/lefthook" = "latest" "aqua:mikefarah/yq" = "latest" +"go:github.com/matryer/moq" = "latest" [env] GOPATH = "{{env.HOME}}/go" @@ -25,3 +26,7 @@ run = "sqlc generate" [tasks.lint] description = "Run golangci-lint" run = "golangci-lint run ./..." + +[tasks.generate] +description = "Run go generate for all packages" +run = "go generate ./..."