Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix getting pull requests that come from forks #563

Merged
merged 5 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ require (
github.com/shurcooL/githubv4 v0.0.0-20240727222349-48295856cce7 // indirect
github.com/shurcooL/graphql v0.0.0-20181231061246-d48a9a75455f // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/stretchr/objx v0.5.2 // indirect
golang.org/x/net v0.31.0 // indirect
golang.org/x/oauth2 v0.24.0 // indirect
golang.org/x/sys v0.27.0 // indirect
Expand Down
2 changes: 1 addition & 1 deletion pull/github_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ func (ghc *GithubContext) Labels(ctx context.Context) ([]string, error) {
func (ghc *GithubContext) IsTargeted(ctx context.Context) (bool, error) {
ref := fmt.Sprintf("refs/heads/%s", ghc.pr.GetHead().GetRef())

prs, err := ListOpenPullRequestsForRef(ctx, ghc.client, ghc.owner, ghc.repo, ref)
prs, err := GetAllOpenPullRequestsForRef(ctx, ghc.client.PullRequests, ghc.owner, ghc.repo, ref)
if err != nil {
return false, errors.Wrap(err, "failed to determine targeted status")
}
Expand Down
116 changes: 90 additions & 26 deletions pull/pull_requests.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2018 Palantir Technologies, Inc.
// Copyright 2024 Palantir Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -23,57 +23,121 @@ import (
"github.com/rs/zerolog"
)

// ListOpenPullRequestsForSHA returns all pull requests where the HEAD of the source branch
// in the pull request matches the given SHA.
func ListOpenPullRequestsForSHA(ctx context.Context, client *github.Client, owner, repoName, SHA string) ([]*github.PullRequest, error) {
prs, _, err := client.PullRequests.ListPullRequestsWithCommit(ctx, owner, repoName, SHA, &github.ListOptions{
// In practice, there should be at most 1-3 PRs for a given commit. In
// exceptional cases, if there are more than 100 PRs, we'll only
// consider the first 100 to avoid paging.
PerPage: 100,
})
if err != nil {
return nil, errors.Wrapf(err, "failed to list pull requests for repository %s/%s", owner, repoName)
}
// GitHubPullRequestClient is an interface that wraps the methods used from the github.Client.
type GitHubPullRequestClient interface {
ListPullRequestsWithCommit(ctx context.Context, owner, repo, sha string, opts *github.ListOptions) ([]*github.PullRequest, *github.Response, error)
List(ctx context.Context, owner, repo string, opts *github.PullRequestListOptions) ([]*github.PullRequest, *github.Response, error)
}

// getOpenPullRequestsForSHA returns all open pull requests where the HEAD of the source branch
// matches the given SHA.
func getOpenPullRequestsForSHA(ctx context.Context, client GitHubPullRequestClient, owner, repo, sha string) ([]*github.PullRequest, error) {
logger := zerolog.Ctx(ctx)
var results []*github.PullRequest
for _, pr := range prs {
if pr.GetState() == "open" && pr.GetHead().GetSHA() == SHA {
results = append(results, pr)
opts := &github.ListOptions{PerPage: 100}

for {
prs, resp, err := client.ListPullRequestsWithCommit(ctx, owner, repo, sha, opts)
if err != nil {
return nil, errors.Wrapf(err, "failed to list pull requests for repository %s/%s", owner, repo)
}

for _, pr := range prs {
if pr.GetState() == "open" && pr.GetHead().GetSHA() == sha {
logger.Debug().Msgf("found open pull request with sha %s", pr.GetHead().GetSHA())
results = append(results, pr)
}
}

if resp.NextPage == 0 {
break
}
opts.Page = resp.NextPage
}

return results, nil
}

func ListOpenPullRequestsForRef(ctx context.Context, client *github.Client, owner, repoName, ref string) ([]*github.PullRequest, error) {
// ListAllOpenPullRequestsFilteredBySHA returns all open pull requests where the HEAD of the source branch
// matches the given SHA by fetching all open PRs and filtering.
func ListAllOpenPullRequestsFilteredBySHA(ctx context.Context, client GitHubPullRequestClient, owner, repo, sha string) ([]*github.PullRequest, error) {
logger := zerolog.Ctx(ctx)
var results []*github.PullRequest
opts := &github.PullRequestListOptions{
State: "open",
ListOptions: github.ListOptions{PerPage: 100},
}

for {
prs, resp, err := client.List(ctx, owner, repo, opts)
if err != nil {
return nil, errors.Wrapf(err, "failed to list pull requests for repository %s/%s", owner, repo)
}

for _, pr := range prs {
if pr.Head.GetSHA() == sha {
logger.Debug().Msgf("found open pull request with sha %s", pr.Head.GetSHA())
results = append(results, pr)
}
}

if resp.NextPage == 0 {
break
}
opts.Page = resp.NextPage
}

return results, nil
}

// GetAllPossibleOpenPullRequestsForSHA attempts to find all open pull requests
// associated with the given SHA using multiple methods in case we are dealing with a fork
func GetAllPossibleOpenPullRequestsForSHA(ctx context.Context, client GitHubPullRequestClient, owner, repo, sha string) ([]*github.PullRequest, error) {
logger := zerolog.Ctx(ctx)

ref = strings.TrimPrefix(ref, "refs/heads/")
prs, err := getOpenPullRequestsForSHA(ctx, client, owner, repo, sha)
if err != nil {
return nil, errors.Wrap(err, "failed to get open pull requests matching the SHA")
}

if len(prs) == 0 {
logger.Debug().Msg("no pull requests found via commit association , searching all pull requests by SHA")
prs, err = ListAllOpenPullRequestsFilteredBySHA(ctx, client, owner, repo, sha)
if err != nil {
return nil, errors.Wrap(err, "failed to list open pull requests matching the SHA")
}
}

return prs, nil
}

// GetAllOpenPullRequestsForRef returns all open pull requests for a given base branch reference.
func GetAllOpenPullRequestsForRef(ctx context.Context, client GitHubPullRequestClient, owner, repo, ref string) ([]*github.PullRequest, error) {
logger := zerolog.Ctx(ctx)
ref = strings.TrimPrefix(ref, "refs/heads/")
opts := &github.PullRequestListOptions{
State: "open",
Base: ref, // Filter by base branch name
ListOptions: github.ListOptions{
PerPage: 100,
},
State: "open",
Base: ref,
ListOptions: github.ListOptions{PerPage: 100},
}

var results []*github.PullRequest
for {
prs, resp, err := client.PullRequests.List(ctx, owner, repoName, opts)
prs, resp, err := client.List(ctx, owner, repo, opts)
if err != nil {
return nil, errors.Wrapf(err, "failed to list pull requests for repository %s/%s", owner, repoName)
return nil, errors.Wrapf(err, "failed to list pull requests for repository %s/%s", owner, repo)
}

for _, pr := range prs {
logger.Debug().Msgf("found open pull request with base ref %s", pr.GetBase().GetRef())
results = append(results, pr)
}

if resp.NextPage == 0 {
break
}
opts.Page = resp.NextPage
}

return results, nil

}
193 changes: 193 additions & 0 deletions pull/pull_requests_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
// Copyright 2024 Palantir Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// pull_test.go

package pull

import (
"context"
"testing"

"github.com/google/go-github/v66/github"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

type MockGitHubClient struct {
mock.Mock
}

func (m *MockGitHubClient) ListPullRequestsWithCommit(ctx context.Context, owner, repo, sha string, opts *github.ListOptions) ([]*github.PullRequest, *github.Response, error) {
args := m.Called(ctx, owner, repo, sha, opts)
return args.Get(0).([]*github.PullRequest), args.Get(1).(*github.Response), args.Error(2)
}

func (m *MockGitHubClient) List(ctx context.Context, owner, repo string, opts *github.PullRequestListOptions) ([]*github.PullRequest, *github.Response, error) {
args := m.Called(ctx, owner, repo, opts)
return args.Get(0).([]*github.PullRequest), args.Get(1).(*github.Response), args.Error(2)
}

func TestGetOpenPullRequestsForSHA(t *testing.T) {
mockClient := new(MockGitHubClient)
ctx := context.Background()
owner := "owner"
repo := "repo"
sha := "sha"

pr := &github.PullRequest{
State: github.String("open"),
Head: &github.PullRequestBranch{SHA: github.String(sha)},
}

mockClient.On("ListPullRequestsWithCommit", ctx, owner, repo, sha, mock.Anything).Return([]*github.PullRequest{pr}, &github.Response{NextPage: 0}, nil)

prs, err := getOpenPullRequestsForSHA(ctx, mockClient, owner, repo, sha)
assert.NoError(t, err)
assert.Len(t, prs, 1)
assert.Equal(t, sha, prs[0].GetHead().GetSHA())

mockClient.AssertExpectations(t)
}

func TestListOpenPullRequestsForSHA(t *testing.T) {
mockClient := new(MockGitHubClient)
ctx := context.Background()
owner := "owner"
repo := "repo"
sha := "sha"

pr := &github.PullRequest{
State: github.String("open"),
Head: &github.PullRequestBranch{SHA: github.String(sha)},
}

mockClient.On("List", ctx, owner, repo, mock.Anything).Return([]*github.PullRequest{pr}, &github.Response{NextPage: 0}, nil)

prs, err := ListAllOpenPullRequestsFilteredBySHA(ctx, mockClient, owner, repo, sha)
assert.NoError(t, err)
assert.Len(t, prs, 1)
assert.Equal(t, sha, prs[0].GetHead().GetSHA())

mockClient.AssertExpectations(t)
}

func TestGetAllPossibleOpenPullRequestsForSHA_FirstMethodReturnsResults(t *testing.T) {
mockClient := new(MockGitHubClient)
ctx := context.Background()
owner := "owner"
repo := "repo"
sha := "sha"

pr := &github.PullRequest{
State: github.String("open"),
Head: &github.PullRequestBranch{SHA: github.String(sha)},
}

// Mock the first method to return a valid pull request.
mockClient.On("ListPullRequestsWithCommit", ctx, owner, repo, sha, mock.Anything).Return([]*github.PullRequest{pr}, &github.Response{NextPage: 0}, nil).Once()
// Mock the second method to not be called.
mockClient.On("List", ctx, owner, repo, mock.Anything).Return(nil, nil, nil).Maybe()

prs, err := GetAllPossibleOpenPullRequestsForSHA(ctx, mockClient, owner, repo, sha)
assert.NoError(t, err)
assert.Len(t, prs, 1)
assert.Equal(t, sha, prs[0].GetHead().GetSHA())

mockClient.AssertExpectations(t)
}

func TestGetAllPossibleOpenPullRequestsForSHA_SecondMethodReturnsResults(t *testing.T) {
mockClient := new(MockGitHubClient)
ctx := context.Background()
owner := "owner"
repo := "repo"
sha := "sha"

pr := &github.PullRequest{
State: github.String("open"),
Head: &github.PullRequestBranch{SHA: github.String(sha)},
}

// Mock the first method to return no results.
mockClient.On("ListPullRequestsWithCommit", ctx, owner, repo, sha, mock.Anything).Return([]*github.PullRequest{}, &github.Response{NextPage: 0}, nil).Once()
// Mock the second method to return a valid pull request.
mockClient.On("List", ctx, owner, repo, mock.Anything).Return([]*github.PullRequest{pr}, &github.Response{NextPage: 0}, nil).Once()

prs, err := GetAllPossibleOpenPullRequestsForSHA(ctx, mockClient, owner, repo, sha)
assert.NoError(t, err)
assert.Len(t, prs, 1)
assert.Equal(t, sha, prs[0].GetHead().GetSHA())

mockClient.AssertExpectations(t)
}

func TestGetAllPossibleOpenPullRequestsForSHA_NoResults(t *testing.T) {
mockClient := new(MockGitHubClient)
ctx := context.Background()
owner := "owner"
repo := "repo"
sha := "sha"

// Mock both methods to return no results.
mockClient.On("ListPullRequestsWithCommit", ctx, owner, repo, sha, mock.Anything).Return([]*github.PullRequest{}, &github.Response{NextPage: 0}, nil).Once()
mockClient.On("List", ctx, owner, repo, mock.Anything).Return([]*github.PullRequest{}, &github.Response{NextPage: 0}, nil).Once()

prs, err := GetAllPossibleOpenPullRequestsForSHA(ctx, mockClient, owner, repo, sha)
assert.NoError(t, err)
assert.Len(t, prs, 0)

mockClient.AssertExpectations(t)
}

func TestGetAllPossibleOpenPullRequestsForSHA_Errors(t *testing.T) {
mockClient := new(MockGitHubClient)
ctx := context.Background()
owner := "owner"
repo := "repo"
sha := "sha"

// Mock the first method to return an error.
mockClient.On("ListPullRequestsWithCommit", ctx, owner, repo, sha, mock.Anything).Return([]*github.PullRequest{}, &github.Response{}, assert.AnError).Once()
// Mock the second method to not be called.
mockClient.On("List", ctx, owner, repo, mock.Anything).Return(nil, nil, nil).Maybe()

prs, err := GetAllPossibleOpenPullRequestsForSHA(ctx, mockClient, owner, repo, sha)
assert.Error(t, err)
assert.Nil(t, prs)

mockClient.AssertExpectations(t)
}

func TestListOpenPullRequestsForRef(t *testing.T) {
mockClient := new(MockGitHubClient)
ctx := context.Background()
owner := "owner"
repo := "repo"
ref := "refs/heads/main"

pr := &github.PullRequest{
State: github.String("open"),
Base: &github.PullRequestBranch{Ref: github.String("main")},
}

mockClient.On("List", ctx, owner, repo, mock.Anything).Return([]*github.PullRequest{pr}, &github.Response{NextPage: 0}, nil)

prs, err := GetAllOpenPullRequestsForRef(ctx, mockClient, owner, repo, ref)
assert.NoError(t, err)
assert.Len(t, prs, 1)
assert.Equal(t, "main", prs[0].GetBase().GetRef())

mockClient.AssertExpectations(t)
}
15 changes: 13 additions & 2 deletions server/handler/check_run.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ func (h *CheckRun) Handle(ctx context.Context, eventType, deliveryID string, pay
}

repo := event.GetRepo()
owner := repo.GetOwner().GetLogin()
repoName := repo.GetName()

installationID := githubapp.GetInstallationIDFromEvent(&event)

ctx, logger := githubapp.PrepareRepoContext(ctx, installationID, repo)
Expand All @@ -56,8 +59,16 @@ func (h *CheckRun) Handle(ctx context.Context, eventType, deliveryID string, pay

prs := event.GetCheckRun().PullRequests
if len(prs) == 0 {
logger.Debug().Msg("Doing nothing since status change event affects no open pull requests")
return nil
logger.Debug().Msg("No pull requests associated with the check run, searching by SHA")
// check runs on fork PRs do not have the PRs attached to the event so we need to filter all PRs by SHA
prs, err = pull.ListAllOpenPullRequestsFilteredBySHA(ctx, client.PullRequests, owner, repoName, event.GetCheckRun().GetHeadSHA())
if err != nil {
return errors.Wrap(err, "failed to determine open pull requests matching the status context change")
}
if len(prs) == 0 {
logger.Debug().Msg("No open pull requests found for the given SHA")
return nil
}
}

for _, pr := range prs {
Expand Down
Loading