diff --git a/pull/github_context.go b/pull/github_context.go index fff2c1e5e..7a7f9f4b6 100644 --- a/pull/github_context.go +++ b/pull/github_context.go @@ -282,7 +282,7 @@ func (ghc *GithubContext) Labels(ctx context.Context) ([]string, error) { func (ghc *GithubContext) IsTargeted(ctx context.Context) (bool, error) { ref := fmt.Sprintf("refs/heads/%s", ghc.pr.GetHead().GetRef()) - prs, err := ListOpenPullRequestsForRef(ctx, ghc.client.PullRequests, ghc.owner, ghc.repo, ref) + prs, err := GetAllOpenPullRequestsForRef(ctx, ghc.client.PullRequests, ghc.owner, ghc.repo, ref) if err != nil { return false, errors.Wrap(err, "failed to determine targeted status") } diff --git a/pull/pull_requests.go b/pull/pull_requests.go index cf23cb029..9023e0a4b 100644 --- a/pull/pull_requests.go +++ b/pull/pull_requests.go @@ -23,15 +23,15 @@ import ( "github.com/rs/zerolog" ) -// GitHubClient is an interface that wraps the methods used from the github.Client. -type GitHubClient interface { +// GitHubPullRequestClient is an interface that wraps the methods used from the github.Client. +type GitHubPullRequestClient interface { ListPullRequestsWithCommit(ctx context.Context, owner, repo, sha string, opts *github.ListOptions) ([]*github.PullRequest, *github.Response, error) List(ctx context.Context, owner, repo string, opts *github.PullRequestListOptions) ([]*github.PullRequest, *github.Response, error) } -// GetOpenPullRequestsForSHA returns all open pull requests where the HEAD of the source branch +// getOpenPullRequestsForSHA returns all open pull requests where the HEAD of the source branch // matches the given SHA. -func GetOpenPullRequestsForSHA(ctx context.Context, client GitHubClient, owner, repo, sha string) ([]*github.PullRequest, error) { +func getOpenPullRequestsForSHA(ctx context.Context, client GitHubPullRequestClient, owner, repo, sha string) ([]*github.PullRequest, error) { logger := zerolog.Ctx(ctx) var results []*github.PullRequest opts := &github.ListOptions{PerPage: 100} @@ -58,9 +58,9 @@ func GetOpenPullRequestsForSHA(ctx context.Context, client GitHubClient, owner, return results, nil } -// ListOpenPullRequestsForSHA returns all open pull requests where the HEAD of the source branch +// ListAllOpenPullRequestsFilteredBySHA returns all open pull requests where the HEAD of the source branch // matches the given SHA by fetching all open PRs and filtering. -func ListOpenPullRequestsForSHA(ctx context.Context, client GitHubClient, owner, repo, sha string) ([]*github.PullRequest, error) { +func ListAllOpenPullRequestsFilteredBySHA(ctx context.Context, client GitHubPullRequestClient, owner, repo, sha string) ([]*github.PullRequest, error) { logger := zerolog.Ctx(ctx) var results []*github.PullRequest opts := &github.PullRequestListOptions{ @@ -92,31 +92,27 @@ func ListOpenPullRequestsForSHA(ctx context.Context, client GitHubClient, owner, // GetAllPossibleOpenPullRequestsForSHA attempts to find all open pull requests // associated with the given SHA using multiple methods in case we are dealing with a fork -func GetAllPossibleOpenPullRequestsForSHA(ctx context.Context, client GitHubClient, owner, repo, sha string) ([]*github.PullRequest, error) { +func GetAllPossibleOpenPullRequestsForSHA(ctx context.Context, client GitHubPullRequestClient, owner, repo, sha string) ([]*github.PullRequest, error) { logger := zerolog.Ctx(ctx) - prs, err := GetOpenPullRequestsForSHA(ctx, client, owner, repo, sha) + prs, err := getOpenPullRequestsForSHA(ctx, client, owner, repo, sha) if err != nil { return nil, errors.Wrap(err, "failed to get open pull requests matching the SHA") } if len(prs) == 0 { - logger.Debug().Msg("No pull requests associated with the check run, searching by SHA") - prs, err = ListOpenPullRequestsForSHA(ctx, client, owner, repo, sha) + logger.Debug().Msg("no pull requests found via commit association , searching all pull requests by SHA") + prs, err = ListAllOpenPullRequestsFilteredBySHA(ctx, client, owner, repo, sha) if err != nil { return nil, errors.Wrap(err, "failed to list open pull requests matching the SHA") } - if len(prs) == 0 { - logger.Debug().Msg("No open pull requests found for the given SHA") - return nil, nil - } } return prs, nil } -// ListOpenPullRequestsForRef returns all open pull requests for a given base branch reference. -func ListOpenPullRequestsForRef(ctx context.Context, client GitHubClient, owner, repo, ref string) ([]*github.PullRequest, error) { +// GetAllOpenPullRequestsForRef returns all open pull requests for a given base branch reference. +func GetAllOpenPullRequestsForRef(ctx context.Context, client GitHubPullRequestClient, owner, repo, ref string) ([]*github.PullRequest, error) { logger := zerolog.Ctx(ctx) ref = strings.TrimPrefix(ref, "refs/heads/") opts := &github.PullRequestListOptions{ diff --git a/pull/pull_requests_test.go b/pull/pull_requests_test.go index fbd42f11c..883117e7e 100644 --- a/pull/pull_requests_test.go +++ b/pull/pull_requests_test.go @@ -53,7 +53,7 @@ func TestGetOpenPullRequestsForSHA(t *testing.T) { mockClient.On("ListPullRequestsWithCommit", ctx, owner, repo, sha, mock.Anything).Return([]*github.PullRequest{pr}, &github.Response{NextPage: 0}, nil) - prs, err := GetOpenPullRequestsForSHA(ctx, mockClient, owner, repo, sha) + prs, err := getOpenPullRequestsForSHA(ctx, mockClient, owner, repo, sha) assert.NoError(t, err) assert.Len(t, prs, 1) assert.Equal(t, sha, prs[0].GetHead().GetSHA()) @@ -75,7 +75,7 @@ func TestListOpenPullRequestsForSHA(t *testing.T) { mockClient.On("List", ctx, owner, repo, mock.Anything).Return([]*github.PullRequest{pr}, &github.Response{NextPage: 0}, nil) - prs, err := ListOpenPullRequestsForSHA(ctx, mockClient, owner, repo, sha) + prs, err := ListAllOpenPullRequestsFilteredBySHA(ctx, mockClient, owner, repo, sha) assert.NoError(t, err) assert.Len(t, prs, 1) assert.Equal(t, sha, prs[0].GetHead().GetSHA()) @@ -184,7 +184,7 @@ func TestListOpenPullRequestsForRef(t *testing.T) { mockClient.On("List", ctx, owner, repo, mock.Anything).Return([]*github.PullRequest{pr}, &github.Response{NextPage: 0}, nil) - prs, err := ListOpenPullRequestsForRef(ctx, mockClient, owner, repo, ref) + prs, err := GetAllOpenPullRequestsForRef(ctx, mockClient, owner, repo, ref) assert.NoError(t, err) assert.Len(t, prs, 1) assert.Equal(t, "main", prs[0].GetBase().GetRef()) diff --git a/server/handler/check_run.go b/server/handler/check_run.go index ac05332e8..1926a64fe 100644 --- a/server/handler/check_run.go +++ b/server/handler/check_run.go @@ -60,8 +60,8 @@ func (h *CheckRun) Handle(ctx context.Context, eventType, deliveryID string, pay prs := event.GetCheckRun().PullRequests if len(prs) == 0 { logger.Debug().Msg("No pull requests associated with the check run, searching by SHA") - // if no PR's were attached with the event let's check with Github in case it is a fork - prs, err = pull.ListOpenPullRequestsForSHA(ctx, client.PullRequests, owner, repoName, event.GetCheckRun().GetHeadSHA()) + // check runs on fork PRs do not have the PRs attached to the event so we need to filter all PRs by SHA + prs, err = pull.ListAllOpenPullRequestsFilteredBySHA(ctx, client.PullRequests, owner, repoName, event.GetCheckRun().GetHeadSHA()) if err != nil { return errors.Wrap(err, "failed to determine open pull requests matching the status context change") } diff --git a/server/handler/push.go b/server/handler/push.go index 1e566baf6..17c3d38e9 100644 --- a/server/handler/push.go +++ b/server/handler/push.go @@ -66,7 +66,7 @@ func (h *Push) Handle(ctx context.Context, eventType, deliveryID string, payload return errors.Wrap(err, "failed to instantiate github client") } - prs, err := pull.ListOpenPullRequestsForRef(ctx, client.PullRequests, owner, repoName, baseRef) + prs, err := pull.GetAllOpenPullRequestsForRef(ctx, client.PullRequests, owner, repoName, baseRef) if err != nil { return errors.Wrap(err, "failed to determine open pull requests matching the push change") }