From f35f98d5ddcb46906f6304f7ef4f89956511bc8f Mon Sep 17 00:00:00 2001 From: Alan Patel Date: Wed, 30 Oct 2024 13:53:49 -0700 Subject: [PATCH] Filter pull requests pulled by Bulldozer (#559) * Filter pull requests pulled by Bulldozer * remove prefix check --------- Co-authored-by: alanpatel --- pull/pull_requests.go | 49 ++++++++++++++++--------------------------- 1 file changed, 18 insertions(+), 31 deletions(-) diff --git a/pull/pull_requests.go b/pull/pull_requests.go index 4bf5c5343..b23bfe1ca 100644 --- a/pull/pull_requests.go +++ b/pull/pull_requests.go @@ -16,7 +16,7 @@ package pull import ( "context" - "fmt" + "strings" "github.com/google/go-github/v65/github" "github.com/pkg/errors" @@ -26,20 +26,22 @@ import ( // ListOpenPullRequestsForSHA returns all pull requests where the HEAD of the source branch // in the pull request matches the given SHA. func ListOpenPullRequestsForSHA(ctx context.Context, client *github.Client, owner, repoName, SHA string) ([]*github.PullRequest, error) { - var results []*github.PullRequest - - openPRs, err := ListOpenPullRequests(ctx, client, owner, repoName) - + prs, _, err := client.PullRequests.ListPullRequestsWithCommit(ctx, owner, repoName, SHA, &github.ListOptions{ + // In practice, there should be at most 1-3 PRs for a given commit. In + // exceptional cases, if there are more than 100 PRs, we'll only + // consider the first 100 to avoid paging. + PerPage: 100, + }) if err != nil { - return nil, err + return nil, errors.Wrapf(err, "failed to list pull requests for repository %s/%s", owner, repoName) } - for _, openPR := range openPRs { - if openPR.Head.GetSHA() == SHA { - results = append(results, openPR) + var results []*github.PullRequest + for _, pr := range prs { + if pr.GetState() == "open" && pr.GetHead().GetSHA() == SHA { + results = append(results, pr) } } - return results, nil } @@ -47,28 +49,11 @@ func ListOpenPullRequestsForRef(ctx context.Context, client *github.Client, owne var results []*github.PullRequest logger := zerolog.Ctx(ctx) - openPRs, err := ListOpenPullRequests(ctx, client, owner, repoName) - - if err != nil { - return nil, err - } - - for _, openPR := range openPRs { - formattedRef := fmt.Sprintf("refs/heads/%s", openPR.GetBase().GetRef()) - logger.Debug().Msgf("found open pull request with base ref %s", formattedRef) - if formattedRef == ref { - results = append(results, openPR) - } - } - - return results, nil -} - -func ListOpenPullRequests(ctx context.Context, client *github.Client, owner, repoName string) ([]*github.PullRequest, error) { - var results []*github.PullRequest + ref = strings.TrimPrefix(ref, "refs/heads/") opts := &github.PullRequestListOptions{ State: "open", + Base: ref, // Filter by base branch name ListOptions: github.ListOptions{ PerPage: 100, }, @@ -77,16 +62,18 @@ func ListOpenPullRequests(ctx context.Context, client *github.Client, owner, rep for { prs, resp, err := client.PullRequests.List(ctx, owner, repoName, opts) if err != nil { - return results, errors.Wrapf(err, "failed to list pull requests for repository %s/%s", owner, repoName) + return nil, errors.Wrapf(err, "failed to list pull requests for repository %s/%s", owner, repoName) } for _, pr := range prs { + logger.Debug().Msgf("found open pull request with base ref %s", pr.GetBase().GetRef()) results = append(results, pr) } if resp.NextPage == 0 { break } - opts.ListOptions.Page = resp.NextPage + opts.Page = resp.NextPage } return results, nil + }