diff --git a/internal/commands/util/usercount/github_test.go b/internal/commands/util/usercount/github_test.go index b39bc403d..3e17e1eab 100644 --- a/internal/commands/util/usercount/github_test.go +++ b/internal/commands/util/usercount/github_test.go @@ -1,13 +1,20 @@ package usercount import ( + "io" + "net/http" + "strconv" "strings" "testing" + "time" + + asserts "github.com/stretchr/testify/assert" + "gotest.tools/assert" "github.com/checkmarx/ast-cli/internal/commands/util/printer" "github.com/checkmarx/ast-cli/internal/params" + "github.com/checkmarx/ast-cli/internal/wrappers" "github.com/checkmarx/ast-cli/internal/wrappers/mock" - "gotest.tools/assert" ) func TestGitHubUserCountOrgs(t *testing.T) { @@ -100,3 +107,56 @@ func TestGitHubUserCountManyOrgs(t *testing.T) { err := cmd.Execute() assert.Error(t, err, tooManyOrgs) } + +func TestHandleRateLimit_WaitsAndRetries(t *testing.T) { + resp := &http.Response{ + StatusCode: http.StatusForbidden, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader("")), + } + resp.Header.Set("X-RateLimit-Remaining", "0") + resetTime := time.Now().Add(50 * time.Second).Unix() + resp.Header.Set("X-RateLimit-Reset", strconv.FormatInt(resetTime, 10)) + defer func() { + if err := resp.Body.Close(); err != nil { + t.Fatal(err) + } + }() + client := &http.Client{} + req, _ := http.NewRequest(http.MethodGet, "http://example.com", http.NoBody) + + start := time.Now() + outResp, err := wrappers.HandleRateLimit(resp, client, req, "http://example.com", "token", map[string]string{}) + defer func() { + if err := outResp.Body.Close(); err != nil { + t.Fatal(err) + } + }() + elapsed := time.Since(start) + + asserts.NoError(t, err) + asserts.GreaterOrEqual(t, elapsed, 20*time.Second) +} + +func TestHandleRateLimit_NoRateLimit(t *testing.T) { + resp := &http.Response{ + StatusCode: http.StatusForbidden, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader("")), + } + defer func() { + if err := resp.Body.Close(); err != nil { + t.Fatal(err) + } + }() + client := &http.Client{} + req, _ := http.NewRequest(http.MethodGet, "http://example.com", http.NoBody) + outResp, err := wrappers.HandleRateLimit(resp, client, req, "http://example.com", "token", map[string]string{}) + defer func() { + if err := outResp.Body.Close(); err != nil { + t.Fatal(err) + } + }() + asserts.NoError(t, err) + assert.Equal(t, resp, outResp) +} diff --git a/internal/wrappers/github-http.go b/internal/wrappers/github-http.go index 6ba2d31e0..cf4345ab6 100644 --- a/internal/wrappers/github-http.go +++ b/internal/wrappers/github-http.go @@ -5,7 +5,9 @@ import ( "fmt" "io" "net/http" + "strconv" "strings" + "time" "github.com/checkmarx/ast-cli/internal/logger" "github.com/checkmarx/ast-cli/internal/params" @@ -248,6 +250,12 @@ func get(client *http.Client, url string, target interface{}, queryParams map[st if err != nil { return nil, err } + if resp.StatusCode == http.StatusForbidden { + resp, err = HandleRateLimit(resp, client, req, url, token, queryParams) + if err != nil { + return nil, err + } + } defer func() { if err == nil { _ = resp.Body.Close() @@ -276,3 +284,19 @@ func get(client *http.Client, url string, target interface{}, queryParams map[st } return resp, nil } + +func HandleRateLimit(resp *http.Response, client *http.Client, req *http.Request, url, token string, queryParams map[string]string) (*http.Response, error) { + remaining := resp.Header.Get("X-RateLimit-Remaining") + reset := resp.Header.Get("X-RateLimit-Reset") + if remaining == "0" && reset != "" { + resetUnix, err := strconv.ParseInt(reset, 10, 64) + if err == nil { + waitDuration := time.Until(time.Unix(resetUnix, 0)) + if waitDuration > 0 { + time.Sleep(waitDuration) + return GetWithQueryParamsAndCustomRequest(client, req, url, token, tokenFormat, queryParams) // Indicate to retry + } + } + } + return resp, nil +}