Skip to content

Commit

Permalink
fix getting pull requests that come from forks (#563)
Browse files Browse the repository at this point in the history
* fix getting pull requests that come from forks

* fix support for PRs from forks and reduce total requests

* format

* refactor pull_requests and logs

---------

Co-authored-by: alanpatel <[email protected]>
  • Loading branch information
alankpatel and alankpatel authored Nov 12, 2024
1 parent ba3bedb commit 12e1e81
Show file tree
Hide file tree
Showing 25 changed files with 5,379 additions and 31 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ require (
github.com/shurcooL/githubv4 v0.0.0-20240727222349-48295856cce7 // indirect
github.com/shurcooL/graphql v0.0.0-20181231061246-d48a9a75455f // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/stretchr/objx v0.5.2 // indirect
golang.org/x/net v0.31.0 // indirect
golang.org/x/oauth2 v0.24.0 // indirect
golang.org/x/sys v0.27.0 // indirect
Expand Down
2 changes: 1 addition & 1 deletion pull/github_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ func (ghc *GithubContext) Labels(ctx context.Context) ([]string, error) {
func (ghc *GithubContext) IsTargeted(ctx context.Context) (bool, error) {
ref := fmt.Sprintf("refs/heads/%s", ghc.pr.GetHead().GetRef())

prs, err := ListOpenPullRequestsForRef(ctx, ghc.client, ghc.owner, ghc.repo, ref)
prs, err := GetAllOpenPullRequestsForRef(ctx, ghc.client.PullRequests, ghc.owner, ghc.repo, ref)
if err != nil {
return false, errors.Wrap(err, "failed to determine targeted status")
}
Expand Down
116 changes: 90 additions & 26 deletions pull/pull_requests.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2018 Palantir Technologies, Inc.
// Copyright 2024 Palantir Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -23,57 +23,121 @@ import (
"github.com/rs/zerolog"
)

// ListOpenPullRequestsForSHA returns all pull requests where the HEAD of the source branch
// in the pull request matches the given SHA.
func ListOpenPullRequestsForSHA(ctx context.Context, client *github.Client, owner, repoName, SHA string) ([]*github.PullRequest, error) {
prs, _, err := client.PullRequests.ListPullRequestsWithCommit(ctx, owner, repoName, SHA, &github.ListOptions{
// In practice, there should be at most 1-3 PRs for a given commit. In
// exceptional cases, if there are more than 100 PRs, we'll only
// consider the first 100 to avoid paging.
PerPage: 100,
})
if err != nil {
return nil, errors.Wrapf(err, "failed to list pull requests for repository %s/%s", owner, repoName)
}
// GitHubPullRequestClient is an interface that wraps the methods used from the github.Client.
type GitHubPullRequestClient interface {
ListPullRequestsWithCommit(ctx context.Context, owner, repo, sha string, opts *github.ListOptions) ([]*github.PullRequest, *github.Response, error)
List(ctx context.Context, owner, repo string, opts *github.PullRequestListOptions) ([]*github.PullRequest, *github.Response, error)
}

// getOpenPullRequestsForSHA returns all open pull requests where the HEAD of the source branch
// matches the given SHA.
func getOpenPullRequestsForSHA(ctx context.Context, client GitHubPullRequestClient, owner, repo, sha string) ([]*github.PullRequest, error) {
logger := zerolog.Ctx(ctx)
var results []*github.PullRequest
for _, pr := range prs {
if pr.GetState() == "open" && pr.GetHead().GetSHA() == SHA {
results = append(results, pr)
opts := &github.ListOptions{PerPage: 100}

for {
prs, resp, err := client.ListPullRequestsWithCommit(ctx, owner, repo, sha, opts)
if err != nil {
return nil, errors.Wrapf(err, "failed to list pull requests for repository %s/%s", owner, repo)
}

for _, pr := range prs {
if pr.GetState() == "open" && pr.GetHead().GetSHA() == sha {
logger.Debug().Msgf("found open pull request with sha %s", pr.GetHead().GetSHA())
results = append(results, pr)
}
}

if resp.NextPage == 0 {
break
}
opts.Page = resp.NextPage
}

return results, nil
}

func ListOpenPullRequestsForRef(ctx context.Context, client *github.Client, owner, repoName, ref string) ([]*github.PullRequest, error) {
// ListAllOpenPullRequestsFilteredBySHA returns all open pull requests where the HEAD of the source branch
// matches the given SHA by fetching all open PRs and filtering.
func ListAllOpenPullRequestsFilteredBySHA(ctx context.Context, client GitHubPullRequestClient, owner, repo, sha string) ([]*github.PullRequest, error) {
logger := zerolog.Ctx(ctx)
var results []*github.PullRequest
opts := &github.PullRequestListOptions{
State: "open",
ListOptions: github.ListOptions{PerPage: 100},
}

for {
prs, resp, err := client.List(ctx, owner, repo, opts)
if err != nil {
return nil, errors.Wrapf(err, "failed to list pull requests for repository %s/%s", owner, repo)
}

for _, pr := range prs {
if pr.Head.GetSHA() == sha {
logger.Debug().Msgf("found open pull request with sha %s", pr.Head.GetSHA())
results = append(results, pr)
}
}

if resp.NextPage == 0 {
break
}
opts.Page = resp.NextPage
}

return results, nil
}

// GetAllPossibleOpenPullRequestsForSHA attempts to find all open pull requests
// associated with the given SHA using multiple methods in case we are dealing with a fork
func GetAllPossibleOpenPullRequestsForSHA(ctx context.Context, client GitHubPullRequestClient, owner, repo, sha string) ([]*github.PullRequest, error) {
logger := zerolog.Ctx(ctx)

ref = strings.TrimPrefix(ref, "refs/heads/")
prs, err := getOpenPullRequestsForSHA(ctx, client, owner, repo, sha)
if err != nil {
return nil, errors.Wrap(err, "failed to get open pull requests matching the SHA")
}

if len(prs) == 0 {
logger.Debug().Msg("no pull requests found via commit association , searching all pull requests by SHA")
prs, err = ListAllOpenPullRequestsFilteredBySHA(ctx, client, owner, repo, sha)
if err != nil {
return nil, errors.Wrap(err, "failed to list open pull requests matching the SHA")
}
}

return prs, nil
}

// GetAllOpenPullRequestsForRef returns all open pull requests for a given base branch reference.
func GetAllOpenPullRequestsForRef(ctx context.Context, client GitHubPullRequestClient, owner, repo, ref string) ([]*github.PullRequest, error) {
logger := zerolog.Ctx(ctx)
ref = strings.TrimPrefix(ref, "refs/heads/")
opts := &github.PullRequestListOptions{
State: "open",
Base: ref, // Filter by base branch name
ListOptions: github.ListOptions{
PerPage: 100,
},
State: "open",
Base: ref,
ListOptions: github.ListOptions{PerPage: 100},
}

var results []*github.PullRequest
for {
prs, resp, err := client.PullRequests.List(ctx, owner, repoName, opts)
prs, resp, err := client.List(ctx, owner, repo, opts)
if err != nil {
return nil, errors.Wrapf(err, "failed to list pull requests for repository %s/%s", owner, repoName)
return nil, errors.Wrapf(err, "failed to list pull requests for repository %s/%s", owner, repo)
}

for _, pr := range prs {
logger.Debug().Msgf("found open pull request with base ref %s", pr.GetBase().GetRef())
results = append(results, pr)
}

if resp.NextPage == 0 {
break
}
opts.Page = resp.NextPage
}

return results, nil

}
193 changes: 193 additions & 0 deletions pull/pull_requests_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
// Copyright 2024 Palantir Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// pull_test.go

package pull

import (
"context"
"testing"

"github.com/google/go-github/v66/github"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

type MockGitHubClient struct {
mock.Mock
}

func (m *MockGitHubClient) ListPullRequestsWithCommit(ctx context.Context, owner, repo, sha string, opts *github.ListOptions) ([]*github.PullRequest, *github.Response, error) {
args := m.Called(ctx, owner, repo, sha, opts)
return args.Get(0).([]*github.PullRequest), args.Get(1).(*github.Response), args.Error(2)
}

func (m *MockGitHubClient) List(ctx context.Context, owner, repo string, opts *github.PullRequestListOptions) ([]*github.PullRequest, *github.Response, error) {
args := m.Called(ctx, owner, repo, opts)
return args.Get(0).([]*github.PullRequest), args.Get(1).(*github.Response), args.Error(2)
}

func TestGetOpenPullRequestsForSHA(t *testing.T) {
mockClient := new(MockGitHubClient)
ctx := context.Background()
owner := "owner"
repo := "repo"
sha := "sha"

pr := &github.PullRequest{
State: github.String("open"),
Head: &github.PullRequestBranch{SHA: github.String(sha)},
}

mockClient.On("ListPullRequestsWithCommit", ctx, owner, repo, sha, mock.Anything).Return([]*github.PullRequest{pr}, &github.Response{NextPage: 0}, nil)

prs, err := getOpenPullRequestsForSHA(ctx, mockClient, owner, repo, sha)
assert.NoError(t, err)
assert.Len(t, prs, 1)
assert.Equal(t, sha, prs[0].GetHead().GetSHA())

mockClient.AssertExpectations(t)
}

func TestListOpenPullRequestsForSHA(t *testing.T) {
mockClient := new(MockGitHubClient)
ctx := context.Background()
owner := "owner"
repo := "repo"
sha := "sha"

pr := &github.PullRequest{
State: github.String("open"),
Head: &github.PullRequestBranch{SHA: github.String(sha)},
}

mockClient.On("List", ctx, owner, repo, mock.Anything).Return([]*github.PullRequest{pr}, &github.Response{NextPage: 0}, nil)

prs, err := ListAllOpenPullRequestsFilteredBySHA(ctx, mockClient, owner, repo, sha)
assert.NoError(t, err)
assert.Len(t, prs, 1)
assert.Equal(t, sha, prs[0].GetHead().GetSHA())

mockClient.AssertExpectations(t)
}

func TestGetAllPossibleOpenPullRequestsForSHA_FirstMethodReturnsResults(t *testing.T) {
mockClient := new(MockGitHubClient)
ctx := context.Background()
owner := "owner"
repo := "repo"
sha := "sha"

pr := &github.PullRequest{
State: github.String("open"),
Head: &github.PullRequestBranch{SHA: github.String(sha)},
}

// Mock the first method to return a valid pull request.
mockClient.On("ListPullRequestsWithCommit", ctx, owner, repo, sha, mock.Anything).Return([]*github.PullRequest{pr}, &github.Response{NextPage: 0}, nil).Once()
// Mock the second method to not be called.
mockClient.On("List", ctx, owner, repo, mock.Anything).Return(nil, nil, nil).Maybe()

prs, err := GetAllPossibleOpenPullRequestsForSHA(ctx, mockClient, owner, repo, sha)
assert.NoError(t, err)
assert.Len(t, prs, 1)
assert.Equal(t, sha, prs[0].GetHead().GetSHA())

mockClient.AssertExpectations(t)
}

func TestGetAllPossibleOpenPullRequestsForSHA_SecondMethodReturnsResults(t *testing.T) {
mockClient := new(MockGitHubClient)
ctx := context.Background()
owner := "owner"
repo := "repo"
sha := "sha"

pr := &github.PullRequest{
State: github.String("open"),
Head: &github.PullRequestBranch{SHA: github.String(sha)},
}

// Mock the first method to return no results.
mockClient.On("ListPullRequestsWithCommit", ctx, owner, repo, sha, mock.Anything).Return([]*github.PullRequest{}, &github.Response{NextPage: 0}, nil).Once()
// Mock the second method to return a valid pull request.
mockClient.On("List", ctx, owner, repo, mock.Anything).Return([]*github.PullRequest{pr}, &github.Response{NextPage: 0}, nil).Once()

prs, err := GetAllPossibleOpenPullRequestsForSHA(ctx, mockClient, owner, repo, sha)
assert.NoError(t, err)
assert.Len(t, prs, 1)
assert.Equal(t, sha, prs[0].GetHead().GetSHA())

mockClient.AssertExpectations(t)
}

func TestGetAllPossibleOpenPullRequestsForSHA_NoResults(t *testing.T) {
mockClient := new(MockGitHubClient)
ctx := context.Background()
owner := "owner"
repo := "repo"
sha := "sha"

// Mock both methods to return no results.
mockClient.On("ListPullRequestsWithCommit", ctx, owner, repo, sha, mock.Anything).Return([]*github.PullRequest{}, &github.Response{NextPage: 0}, nil).Once()
mockClient.On("List", ctx, owner, repo, mock.Anything).Return([]*github.PullRequest{}, &github.Response{NextPage: 0}, nil).Once()

prs, err := GetAllPossibleOpenPullRequestsForSHA(ctx, mockClient, owner, repo, sha)
assert.NoError(t, err)
assert.Len(t, prs, 0)

mockClient.AssertExpectations(t)
}

func TestGetAllPossibleOpenPullRequestsForSHA_Errors(t *testing.T) {
mockClient := new(MockGitHubClient)
ctx := context.Background()
owner := "owner"
repo := "repo"
sha := "sha"

// Mock the first method to return an error.
mockClient.On("ListPullRequestsWithCommit", ctx, owner, repo, sha, mock.Anything).Return([]*github.PullRequest{}, &github.Response{}, assert.AnError).Once()
// Mock the second method to not be called.
mockClient.On("List", ctx, owner, repo, mock.Anything).Return(nil, nil, nil).Maybe()

prs, err := GetAllPossibleOpenPullRequestsForSHA(ctx, mockClient, owner, repo, sha)
assert.Error(t, err)
assert.Nil(t, prs)

mockClient.AssertExpectations(t)
}

func TestListOpenPullRequestsForRef(t *testing.T) {
mockClient := new(MockGitHubClient)
ctx := context.Background()
owner := "owner"
repo := "repo"
ref := "refs/heads/main"

pr := &github.PullRequest{
State: github.String("open"),
Base: &github.PullRequestBranch{Ref: github.String("main")},
}

mockClient.On("List", ctx, owner, repo, mock.Anything).Return([]*github.PullRequest{pr}, &github.Response{NextPage: 0}, nil)

prs, err := GetAllOpenPullRequestsForRef(ctx, mockClient, owner, repo, ref)
assert.NoError(t, err)
assert.Len(t, prs, 1)
assert.Equal(t, "main", prs[0].GetBase().GetRef())

mockClient.AssertExpectations(t)
}
15 changes: 13 additions & 2 deletions server/handler/check_run.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ func (h *CheckRun) Handle(ctx context.Context, eventType, deliveryID string, pay
}

repo := event.GetRepo()
owner := repo.GetOwner().GetLogin()
repoName := repo.GetName()

installationID := githubapp.GetInstallationIDFromEvent(&event)

ctx, logger := githubapp.PrepareRepoContext(ctx, installationID, repo)
Expand All @@ -56,8 +59,16 @@ func (h *CheckRun) Handle(ctx context.Context, eventType, deliveryID string, pay

prs := event.GetCheckRun().PullRequests
if len(prs) == 0 {
logger.Debug().Msg("Doing nothing since status change event affects no open pull requests")
return nil
logger.Debug().Msg("No pull requests associated with the check run, searching by SHA")
// check runs on fork PRs do not have the PRs attached to the event so we need to filter all PRs by SHA
prs, err = pull.ListAllOpenPullRequestsFilteredBySHA(ctx, client.PullRequests, owner, repoName, event.GetCheckRun().GetHeadSHA())
if err != nil {
return errors.Wrap(err, "failed to determine open pull requests matching the status context change")
}
if len(prs) == 0 {
logger.Debug().Msg("No open pull requests found for the given SHA")
return nil
}
}

for _, pr := range prs {
Expand Down
Loading

0 comments on commit 12e1e81

Please sign in to comment.