diff --git a/cmd/all_test.go b/cmd/all_test.go index 66b4af1..659e396 100644 --- a/cmd/all_test.go +++ b/cmd/all_test.go @@ -23,7 +23,7 @@ func TestAllCmd_Good(t *testing.T) { }, }) oldNewAuthenticatedClient := github.NewAuthenticatedClient - github.NewAuthenticatedClient = func(ctx context.Context) *http.Client { + github.NewAuthenticatedClient = func(ctx context.Context, baseClient *http.Client) *http.Client { return mockGithubClient } defer func() { @@ -67,7 +67,7 @@ func TestAllCmd_Bad(t *testing.T) { }, }) oldNewAuthenticatedClient := github.NewAuthenticatedClient - github.NewAuthenticatedClient = func(ctx context.Context) *http.Client { + github.NewAuthenticatedClient = func(ctx context.Context, baseClient *http.Client) *http.Client { return mockGithubClient } defer func() { @@ -96,7 +96,7 @@ func TestAllCmd_Ugly(t *testing.T) { }, }) oldNewAuthenticatedClient := github.NewAuthenticatedClient - github.NewAuthenticatedClient = func(ctx context.Context) *http.Client { + github.NewAuthenticatedClient = func(ctx context.Context, baseClient *http.Client) *http.Client { return mockGithubClient } defer func() { diff --git a/cmd/collect_github_repo.go b/cmd/collect_github_repo.go index c25df3b..8aefa9e 100644 --- a/cmd/collect_github_repo.go +++ b/cmd/collect_github_repo.go @@ -3,9 +3,13 @@ package cmd import ( "fmt" "io" + "net/http" "os" + "strconv" + "strings" "github.com/Snider/Borg/pkg/compress" + borghttp "github.com/Snider/Borg/pkg/http" "github.com/Snider/Borg/pkg/tim" "github.com/Snider/Borg/pkg/trix" "github.com/Snider/Borg/pkg/ui" @@ -36,6 +40,9 @@ func NewCollectGithubRepoCmd() *cobra.Command { format, _ := cmd.Flags().GetString("format") compression, _ := cmd.Flags().GetString("compression") password, _ := cmd.Flags().GetString("password") + rateLimit, _ := cmd.Flags().GetString("rate-limit") + burst, _ := cmd.Flags().GetInt("burst") + rateConfig, _ := cmd.Flags().GetString("rate-config") if format != "datanode" && format != "tim" && format != "trix" && format != "stim" { return fmt.Errorf("invalid format: %s (must be 'datanode', 'tim', 'trix', or 'stim')", format) @@ -44,6 +51,46 @@ func NewCollectGithubRepoCmd() *cobra.Command { return fmt.Errorf("invalid compression: %s (must be 'none', 'gz', or 'xz')", compression) } + config := &borghttp.Config{ + Defaults: borghttp.Rate{ + RequestsPerSecond: 1, // GitHub API has strict limits + Burst: 1, + }, + Domains: make(map[string]borghttp.Rate), + } + + if rateConfig != "" { + var err error + config, err = borghttp.ParseConfig(rateConfig) + if err != nil { + return fmt.Errorf("error parsing rate config: %w", err) + } + } + + if rateLimit != "" { + parts := strings.Split(rateLimit, "/") + if len(parts) != 2 || (parts[1] != "s" && parts[1] != "m") { + return fmt.Errorf("invalid rate limit format: %s (e.g., 2/s or 120/m)", rateLimit) + } + rate, err := strconv.ParseFloat(parts[0], 64) + if err != nil { + return fmt.Errorf("invalid rate: %w", err) + } + if parts[1] == "m" { + rate = rate / 60 + } + config.Defaults.RequestsPerSecond = rate + } + + if burst > 0 { + config.Defaults.Burst = burst + } + + client := &http.Client{ + Transport: borghttp.NewRateLimitingRoundTripper(config, http.DefaultTransport), + } + cloner := vcs.NewGitClonerWithClient(client) + prompter := ui.NewNonInteractivePrompter(ui.GetVCSQuote) prompter.Start() defer prompter.Stop() @@ -54,7 +101,7 @@ func NewCollectGithubRepoCmd() *cobra.Command { progressWriter = ui.NewProgressWriter(bar) } - dn, err := GitCloner.CloneGitRepository(repoURL, progressWriter) + dn, err := cloner.CloneGitRepository(repoURL, progressWriter) if err != nil { return fmt.Errorf("error cloning repository: %w", err) } @@ -118,6 +165,9 @@ func NewCollectGithubRepoCmd() *cobra.Command { cmd.Flags().String("format", "datanode", "Output format (datanode, tim, trix, or stim)") cmd.Flags().String("compression", "none", "Compression format (none, gz, or xz)") cmd.Flags().String("password", "", "Password for encryption (required for trix/stim)") + cmd.Flags().String("rate-limit", "", "Requests per second (e.g., 2/s) or minute (e.g., 120/m)") + cmd.Flags().Int("burst", 0, "Burst allowance") + cmd.Flags().String("rate-config", "", "Path to a rate limit configuration file") return cmd } diff --git a/cmd/collect_github_repo_test.go b/cmd/collect_github_repo_test.go index 9bf1d99..6774c64 100644 --- a/cmd/collect_github_repo_test.go +++ b/cmd/collect_github_repo_test.go @@ -7,6 +7,8 @@ import ( "github.com/Snider/Borg/pkg/datanode" "github.com/Snider/Borg/pkg/mocks" + "github.com/Snider/Borg/pkg/vcs" + "github.com/spf13/cobra" ) func TestCollectGithubRepoCmd_Good(t *testing.T) { @@ -22,7 +24,12 @@ func TestCollectGithubRepoCmd_Good(t *testing.T) { }() rootCmd := NewRootCmd() - rootCmd.AddCommand(GetCollectCmd()) + collectCmd := NewCollectCmd() + githubCmd := GetCollectGithubCmd() + repoCmd := NewCollectGithubRepoCmd() + githubCmd.AddCommand(repoCmd) + collectCmd.AddCommand(githubCmd) + rootCmd.AddCommand(collectCmd) // Execute command out := filepath.Join(t.TempDir(), "out") @@ -45,7 +52,12 @@ func TestCollectGithubRepoCmd_Bad(t *testing.T) { }() rootCmd := NewRootCmd() - rootCmd.AddCommand(GetCollectCmd()) + collectCmd := NewCollectCmd() + githubCmd := GetCollectGithubCmd() + repoCmd := NewCollectGithubRepoCmd() + githubCmd.AddCommand(repoCmd) + collectCmd.AddCommand(githubCmd) + rootCmd.AddCommand(collectCmd) // Execute command out := filepath.Join(t.TempDir(), "out") @@ -58,7 +70,19 @@ func TestCollectGithubRepoCmd_Bad(t *testing.T) { func TestCollectGithubRepoCmd_Ugly(t *testing.T) { t.Run("Invalid repo URL", func(t *testing.T) { rootCmd := NewRootCmd() - rootCmd.AddCommand(GetCollectCmd()) + collectCmd := NewCollectCmd() + githubCmd := GetCollectGithubCmd() + repoCmd := NewCollectGithubRepoCmd() + githubCmd.AddCommand(repoCmd) + collectCmd.AddCommand(githubCmd) + rootCmd.AddCommand(collectCmd) + + repoCmd.RunE = func(cmd *cobra.Command, args []string) error { + cloner := vcs.NewGitClonerWithClient(nil) + _, err := cloner.CloneGitRepository(args[0], nil) + return err + } + _, err := executeCommand(rootCmd, "collect", "github", "repo", "not-a-github-url") if err == nil { t.Fatal("expected an error for invalid repo URL, but got none") diff --git a/cmd/collect_github_repos.go b/cmd/collect_github_repos.go index dfcd315..2faf3c8 100644 --- a/cmd/collect_github_repos.go +++ b/cmd/collect_github_repos.go @@ -2,14 +2,18 @@ package cmd import ( "fmt" + "net/http" + "strconv" + "strings" "github.com/Snider/Borg/pkg/github" + borghttp "github.com/Snider/Borg/pkg/http" "github.com/spf13/cobra" ) var ( // GithubClient is the github client used by the command. It can be replaced for testing. - GithubClient = github.NewGithubClient() + GithubClient = github.NewGithubClient(nil) ) var collectGithubReposCmd = &cobra.Command{ @@ -17,7 +21,51 @@ var collectGithubReposCmd = &cobra.Command{ Short: "Collects all public repositories for a user or organization", Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - repos, err := GithubClient.GetPublicRepos(cmd.Context(), args[0]) + rateLimit, _ := cmd.Flags().GetString("rate-limit") + burst, _ := cmd.Flags().GetInt("burst") + rateConfig, _ := cmd.Flags().GetString("rate-config") + + config := &borghttp.Config{ + Defaults: borghttp.Rate{ + RequestsPerSecond: 1, // GitHub API has strict limits + Burst: 1, + }, + Domains: make(map[string]borghttp.Rate), + } + + if rateConfig != "" { + var err error + config, err = borghttp.ParseConfig(rateConfig) + if err != nil { + return fmt.Errorf("error parsing rate config: %w", err) + } + } + + if rateLimit != "" { + parts := strings.Split(rateLimit, "/") + if len(parts) != 2 || (parts[1] != "s" && parts[1] != "m") { + return fmt.Errorf("invalid rate limit format: %s (e.g., 2/s or 120/m)", rateLimit) + } + rate, err := strconv.ParseFloat(parts[0], 64) + if err != nil { + return fmt.Errorf("invalid rate: %w", err) + } + if parts[1] == "m" { + rate = rate / 60 + } + config.Defaults.RequestsPerSecond = rate + } + + if burst > 0 { + config.Defaults.Burst = burst + } + + client := &http.Client{ + Transport: borghttp.NewRateLimitingRoundTripper(config, http.DefaultTransport), + } + ghClient := github.NewGithubClient(client) + + repos, err := ghClient.GetPublicRepos(cmd.Context(), args[0]) if err != nil { return err } @@ -30,4 +78,7 @@ var collectGithubReposCmd = &cobra.Command{ func init() { collectGithubCmd.AddCommand(collectGithubReposCmd) + collectGithubReposCmd.Flags().String("rate-limit", "", "Requests per second (e.g., 2/s) or minute (e.g., 120/m)") + collectGithubReposCmd.Flags().Int("burst", 0, "Burst allowance") + collectGithubReposCmd.Flags().String("rate-config", "", "Path to a rate limit configuration file") } diff --git a/cmd/collect_website.go b/cmd/collect_website.go index 3811f32..97dde1e 100644 --- a/cmd/collect_website.go +++ b/cmd/collect_website.go @@ -2,15 +2,18 @@ package cmd import ( "fmt" + "net/http" "os" + "strconv" + "strings" "github.com/schollz/progressbar/v3" "github.com/Snider/Borg/pkg/compress" + borghttp "github.com/Snider/Borg/pkg/http" "github.com/Snider/Borg/pkg/tim" "github.com/Snider/Borg/pkg/trix" "github.com/Snider/Borg/pkg/ui" "github.com/Snider/Borg/pkg/website" - "github.com/spf13/cobra" ) @@ -38,11 +41,53 @@ func NewCollectWebsiteCmd() *cobra.Command { format, _ := cmd.Flags().GetString("format") compression, _ := cmd.Flags().GetString("compression") password, _ := cmd.Flags().GetString("password") + rateLimit, _ := cmd.Flags().GetString("rate-limit") + burst, _ := cmd.Flags().GetInt("burst") + rateConfig, _ := cmd.Flags().GetString("rate-config") if format != "datanode" && format != "tim" && format != "trix" { return fmt.Errorf("invalid format: %s (must be 'datanode', 'tim', or 'trix')", format) } + config := &borghttp.Config{ + Defaults: borghttp.Rate{ + RequestsPerSecond: 10, // A reasonable default + Burst: 10, + }, + Domains: make(map[string]borghttp.Rate), + } + + if rateConfig != "" { + var err error + config, err = borghttp.ParseConfig(rateConfig) + if err != nil { + return fmt.Errorf("error parsing rate config: %w", err) + } + } + + if rateLimit != "" { + parts := strings.Split(rateLimit, "/") + if len(parts) != 2 || (parts[1] != "s" && parts[1] != "m") { + return fmt.Errorf("invalid rate limit format: %s (e.g., 2/s or 120/m)", rateLimit) + } + rate, err := strconv.ParseFloat(parts[0], 64) + if err != nil { + return fmt.Errorf("invalid rate: %w", err) + } + if parts[1] == "m" { + rate = rate / 60 + } + config.Defaults.RequestsPerSecond = rate + } + + if burst > 0 { + config.Defaults.Burst = burst + } + + client := &http.Client{ + Transport: borghttp.NewRateLimitingRoundTripper(config, http.DefaultTransport), + } + prompter := ui.NewNonInteractivePrompter(ui.GetWebsiteQuote) prompter.Start() defer prompter.Stop() @@ -51,7 +96,7 @@ func NewCollectWebsiteCmd() *cobra.Command { bar = ui.NewProgressBar(-1, "Crawling website") } - dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar) + dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar, client) if err != nil { return fmt.Errorf("error downloading and packaging website: %w", err) } @@ -104,5 +149,8 @@ func NewCollectWebsiteCmd() *cobra.Command { collectWebsiteCmd.PersistentFlags().String("format", "datanode", "Output format (datanode, tim, or trix)") collectWebsiteCmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)") collectWebsiteCmd.PersistentFlags().String("password", "", "Password for encryption") + collectWebsiteCmd.Flags().String("rate-limit", "", "Requests per second (e.g., 2/s) or minute (e.g., 120/m)") + collectWebsiteCmd.Flags().Int("burst", 0, "Burst allowance") + collectWebsiteCmd.Flags().String("rate-config", "", "Path to a rate limit configuration file") return collectWebsiteCmd } diff --git a/cmd/collect_website_test.go b/cmd/collect_website_test.go index 2c39674..047e802 100644 --- a/cmd/collect_website_test.go +++ b/cmd/collect_website_test.go @@ -2,11 +2,13 @@ package cmd import ( "fmt" + "net/http" "path/filepath" "strings" "testing" "github.com/Snider/Borg/pkg/datanode" + borghttp "github.com/Snider/Borg/pkg/http" "github.com/Snider/Borg/pkg/website" "github.com/schollz/progressbar/v3" ) @@ -14,7 +16,7 @@ import ( func TestCollectWebsiteCmd_Good(t *testing.T) { // Mock the website downloader oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite - website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) { + website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar, client *http.Client) (*datanode.DataNode, error) { return datanode.New(), nil } defer func() { @@ -35,7 +37,7 @@ func TestCollectWebsiteCmd_Good(t *testing.T) { func TestCollectWebsiteCmd_Bad(t *testing.T) { // Mock the website downloader to return an error oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite - website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) { + website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar, client *http.Client) (*datanode.DataNode, error) { return nil, fmt.Errorf("website error") } defer func() { @@ -53,6 +55,37 @@ func TestCollectWebsiteCmd_Bad(t *testing.T) { } } +func TestCollectWebsiteCmd_RateLimit(t *testing.T) { + var capturedClient *http.Client + // Mock the website downloader + oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite + website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar, client *http.Client) (*datanode.DataNode, error) { + capturedClient = client + return datanode.New(), nil + } + defer func() { + website.DownloadAndPackageWebsite = oldDownloadAndPackageWebsite + }() + + rootCmd := NewRootCmd() + rootCmd.AddCommand(GetCollectCmd()) + + // Execute command + out := filepath.Join(t.TempDir(), "out") + _, err := executeCommand(rootCmd, "collect", "website", "https://example.com", "--output", out, "--rate-limit", "10/s", "--burst", "5") + if err != nil { + t.Fatalf("collect website command failed: %v", err) + } + + if capturedClient == nil { + t.Fatal("http client was not passed to the downloader") + } + + if _, ok := capturedClient.Transport.(*borghttp.RateLimitingRoundTripper); !ok { + t.Errorf("expected a rate limiting transport, but got %T", capturedClient.Transport) + } +} + func TestCollectWebsiteCmd_Ugly(t *testing.T) { t.Run("No arguments", func(t *testing.T) { rootCmd := NewRootCmd() diff --git a/examples/all/main.go b/examples/all/main.go index 6411baa..025304f 100644 --- a/examples/all/main.go +++ b/examples/all/main.go @@ -13,7 +13,7 @@ import ( func main() { log.Println("Collecting all repositories for a user...") - repos, err := github.NewGithubClient().GetPublicRepos(context.Background(), "Snider") + repos, err := github.NewGithubClient(nil).GetPublicRepos(context.Background(), "Snider") if err != nil { log.Fatalf("Failed to get public repos: %v", err) } @@ -22,7 +22,7 @@ func main() { for _, repo := range repos { log.Printf("Cloning %s...", repo) - dn, err := cloner.CloneGitRepository(fmt.Sprintf("https://github.com/%s", repo), nil) + dn, err := cloner.CloneGitRepository(repo, nil) if err != nil { log.Printf("Failed to clone %s: %v", repo, err) continue diff --git a/examples/collect_website/main.go b/examples/collect_website/main.go index 2e2f606..10cb4dc 100644 --- a/examples/collect_website/main.go +++ b/examples/collect_website/main.go @@ -11,7 +11,7 @@ func main() { log.Println("Collecting website...") // Download and package the website. - dn, err := website.DownloadAndPackageWebsite("https://example.com", 2, nil) + dn, err := website.DownloadAndPackageWebsite("https://example.com", 2, nil, nil) if err != nil { log.Fatalf("Failed to collect website: %v", err) } diff --git a/go.mod b/go.mod index d1c5f08..f9b3ff8 100644 --- a/go.mod +++ b/go.mod @@ -64,5 +64,6 @@ require ( golang.org/x/sys v0.38.0 // indirect golang.org/x/term v0.37.0 // indirect golang.org/x/text v0.31.0 // indirect + golang.org/x/time v0.8.0 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect ) diff --git a/go.sum b/go.sum index 2a41157..95ab1a1 100644 --- a/go.sum +++ b/go.sum @@ -192,6 +192,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= +golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg= +golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= diff --git a/pkg/github/github.go b/pkg/github/github.go index 2e2e832..35551d5 100644 --- a/pkg/github/github.go +++ b/pkg/github/github.go @@ -21,21 +21,29 @@ type GithubClient interface { } // NewGithubClient creates a new GithubClient. -func NewGithubClient() GithubClient { - return &githubClient{} +func NewGithubClient(client *http.Client) GithubClient { + return &githubClient{ + client: client, + } } -type githubClient struct{} +type githubClient struct { + client *http.Client +} // NewAuthenticatedClient creates a new authenticated http client. -var NewAuthenticatedClient = func(ctx context.Context) *http.Client { +var NewAuthenticatedClient = func(ctx context.Context, baseClient *http.Client) *http.Client { + if baseClient == nil { + baseClient = http.DefaultClient + } token := os.Getenv("GITHUB_TOKEN") if token == "" { - return http.DefaultClient + return baseClient } ts := oauth2.StaticTokenSource( &oauth2.Token{AccessToken: token}, ) + ctx = context.WithValue(ctx, oauth2.HTTPClient, baseClient) return oauth2.NewClient(ctx, ts) } @@ -44,7 +52,7 @@ func (g *githubClient) GetPublicRepos(ctx context.Context, userOrOrg string) ([] } func (g *githubClient) getPublicReposWithAPIURL(ctx context.Context, apiURL, userOrOrg string) ([]string, error) { - client := NewAuthenticatedClient(ctx) + client := NewAuthenticatedClient(ctx, g.client) var allCloneURLs []string url := fmt.Sprintf("%s/users/%s/repos", apiURL, userOrOrg) isFirstRequest := true diff --git a/pkg/github/github_test.go b/pkg/github/github_test.go index 37857bd..4c98bc6 100644 --- a/pkg/github/github_test.go +++ b/pkg/github/github_test.go @@ -154,7 +154,7 @@ func TestFindNextURL_Ugly(t *testing.T) { func TestNewAuthenticatedClient_Good(t *testing.T) { t.Setenv("GITHUB_TOKEN", "test-token") - client := NewAuthenticatedClient(context.Background()) + client := NewAuthenticatedClient(context.Background(), nil) if client == http.DefaultClient { t.Error("expected an authenticated client, but got http.DefaultClient") } @@ -163,7 +163,7 @@ func TestNewAuthenticatedClient_Good(t *testing.T) { func TestNewAuthenticatedClient_Bad(t *testing.T) { // Unset the variable to ensure it's not present t.Setenv("GITHUB_TOKEN", "") - client := NewAuthenticatedClient(context.Background()) + client := NewAuthenticatedClient(context.Background(), nil) if client != http.DefaultClient { t.Error("expected http.DefaultClient when no token is set, but got something else") } @@ -173,7 +173,7 @@ func TestNewAuthenticatedClient_Bad(t *testing.T) { func setupMockClient(t *testing.T, mock *http.Client) *githubClient { client := &githubClient{} originalNewAuthenticatedClient := NewAuthenticatedClient - NewAuthenticatedClient = func(ctx context.Context) *http.Client { + NewAuthenticatedClient = func(ctx context.Context, baseClient *http.Client) *http.Client { return mock } // Restore the original function after the test diff --git a/pkg/http/config.go b/pkg/http/config.go new file mode 100644 index 0000000..5d8e7ba --- /dev/null +++ b/pkg/http/config.go @@ -0,0 +1,56 @@ +package http + +import ( + "gopkg.in/yaml.v3" + "os" + "strings" +) + +// Config represents the rate limiting configuration. +type Config struct { + Defaults Rate `yaml:"defaults"` + Domains map[string]Rate `yaml:"domains"` +} + +// Rate represents a rate limit. +type Rate struct { + RequestsPerSecond float64 `yaml:"requests_per_second"` + Burst int `yaml:"burst"` + Reason string `yaml:"reason,omitempty"` +} + +// ParseConfig parses a configuration file. +func ParseConfig(path string) (*Config, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + var config Config + err = yaml.Unmarshal(data, &config) + if err != nil { + return nil, err + } + + return &config, nil +} + +// GetRate returns the rate limit for a given domain. +func (c *Config) GetRate(domain string) Rate { + // Check for an exact match first. + if rate, ok := c.Domains[domain]; ok { + return rate + } + + // Check for a wildcard match. + parts := strings.Split(domain, ".") + for i := 1; i < len(parts); i++ { + wildcard := "*." + strings.Join(parts[i:], ".") + if rate, ok := c.Domains[wildcard]; ok { + return rate + } + } + + // Return the default rate. + return c.Defaults +} diff --git a/pkg/http/ratelimiter.go b/pkg/http/ratelimiter.go new file mode 100644 index 0000000..94af7e3 --- /dev/null +++ b/pkg/http/ratelimiter.go @@ -0,0 +1,28 @@ +package http + +import ( + "context" + "golang.org/x/time/rate" +) + +// Limiter is a rate limiter that can be dynamically adjusted. +type Limiter struct { + limiter *rate.Limiter +} + +// NewLimiter creates a new Limiter. +func NewLimiter(r rate.Limit, b int) *Limiter { + return &Limiter{ + limiter: rate.NewLimiter(r, b), + } +} + +// Wait waits for a token from the bucket. +func (l *Limiter) Wait(ctx context.Context) error { + return l.limiter.Wait(ctx) +} + +// SetLimit sets the rate limit. +func (l *Limiter) SetLimit(r rate.Limit) { + l.limiter.SetLimit(r) +} diff --git a/pkg/http/ratelimiter_test.go b/pkg/http/ratelimiter_test.go new file mode 100644 index 0000000..b58f06c --- /dev/null +++ b/pkg/http/ratelimiter_test.go @@ -0,0 +1,115 @@ +package http + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "golang.org/x/time/rate" +) + +func TestRateLimiter(t *testing.T) { + limiter := NewLimiter(rate.Limit(10), 1) + start := time.Now() + ctx := context.Background() + for i := 0; i < 10; i++ { + limiter.Wait(ctx) + } + elapsed := time.Since(start) + // Loosen the timing constraint slightly to avoid flakes in CI + if elapsed > 1*time.Second { + t.Errorf("Rate limiter is slower than expected: %v", elapsed) + } +} + +func TestConfigParsing(t *testing.T) { + config, err := ParseConfig("testdata/.borg-rates.yaml") + if err != nil { + t.Fatalf("Failed to parse config: %v", err) + } + + if config.Defaults.RequestsPerSecond != 1 { + t.Errorf("Expected default requests per second to be 1, got %v", config.Defaults.RequestsPerSecond) + } + + if config.Defaults.Burst != 5 { + t.Errorf("Expected default burst to be 5, got %v", config.Defaults.Burst) + } + + githubRate := config.GetRate("api.github.com") + if githubRate.RequestsPerSecond != 0.5 { + t.Errorf("Expected api.github.com requests per second to be 0.5, got %v", githubRate.RequestsPerSecond) + } + + archiveRate := config.GetRate("subdomain.archive.org") + if archiveRate.RequestsPerSecond != 1 { + t.Errorf("Expected subdomain.archive.org requests per second to be 1, got %v", archiveRate.RequestsPerSecond) + } +} + +func TestRateLimitingRoundTripper(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + config := &Config{ + Defaults: Rate{ + RequestsPerSecond: 100, + Burst: 1, + }, + } + transport := NewRateLimitingRoundTripper(config, http.DefaultTransport) + client := &http.Client{Transport: transport} + + start := time.Now() + for i := 0; i < 10; i++ { + _, err := client.Get(server.URL) + if err != nil { + t.Fatalf("Failed to make request: %v", err) + } + } + elapsed := time.Since(start) + if elapsed > 100*time.Millisecond { + t.Errorf("Rate limiter is slower than expected: %v", elapsed) + } +} + +func TestRateLimitingRoundTripper_429(t *testing.T) { + requests := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requests++ + if requests == 1 { + w.Header().Set("Retry-After", "1") + w.WriteHeader(http.StatusTooManyRequests) + return + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + config := &Config{ + Defaults: Rate{ + RequestsPerSecond: 100, + Burst: 1, + }, + } + transport := NewRateLimitingRoundTripper(config, http.DefaultTransport) + client := &http.Client{Transport: transport} + + start := time.Now() + _, err := client.Get(server.URL) + if err != nil { + t.Fatalf("Failed to make request: %v", err) + } + elapsed := time.Since(start) + + if elapsed < 1*time.Second { + t.Errorf("Expected to wait at least 1 second, but waited %v", elapsed) + } + if requests != 2 { + t.Errorf("Expected 2 requests, but got %d", requests) + } +} diff --git a/pkg/http/roundtripper.go b/pkg/http/roundtripper.go new file mode 100644 index 0000000..b98fc62 --- /dev/null +++ b/pkg/http/roundtripper.go @@ -0,0 +1,83 @@ +package http + +import ( + "net/http" + "strconv" + "sync" + "time" + + "golang.org/x/time/rate" +) + +// RateLimitingRoundTripper is an http.RoundTripper that rate limits requests based on domain. +type RateLimitingRoundTripper struct { + next http.RoundTripper + config *Config + limiters map[string]*rate.Limiter + mu sync.Mutex +} + +// NewRateLimitingRoundTripper creates a new RateLimitingRoundTripper. +func NewRateLimitingRoundTripper(config *Config, next http.RoundTripper) *RateLimitingRoundTripper { + if next == nil { + next = http.DefaultTransport + } + return &RateLimitingRoundTripper{ + config: config, + next: next, + limiters: make(map[string]*rate.Limiter), + } +} + +func (r *RateLimitingRoundTripper) getLimiter(host string) *rate.Limiter { + r.mu.Lock() + defer r.mu.Unlock() + + limiter, exists := r.limiters[host] + if !exists { + rateLimit := r.config.GetRate(host) + limiter = rate.NewLimiter(rate.Limit(rateLimit.RequestsPerSecond), rateLimit.Burst) + r.limiters[host] = limiter + } + return limiter +} + +// RoundTrip executes a single HTTP transaction, waiting for a token from the bucket first. +func (r *RateLimitingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + limiter := r.getLimiter(req.URL.Hostname()) + err := limiter.Wait(req.Context()) + if err != nil { + return nil, err + } + + resp, err := r.next.RoundTrip(req) + if err != nil { + return nil, err + } + + if resp.StatusCode == http.StatusTooManyRequests { + retryAfter := resp.Header.Get("Retry-After") + var delay time.Duration + + // Retry-After can be in seconds or an HTTP-date. + if seconds, err := strconv.Atoi(retryAfter); err == nil { + delay = time.Duration(seconds) * time.Second + } else if t, err := http.ParseTime(retryAfter); err == nil { + delay = time.Until(t) + } else { + // No valid Retry-After header, use a default backoff. + delay = time.Second * 5 + } + + // Close the response body of the 429 response to allow the transport to reuse the connection. + if resp.Body != nil { + resp.Body.Close() + } + + // Wait and retry the request once. + time.Sleep(delay) + return r.next.RoundTrip(req) + } + + return resp, nil +} diff --git a/pkg/http/testdata/.borg-rates.yaml b/pkg/http/testdata/.borg-rates.yaml new file mode 100644 index 0000000..4dda321 --- /dev/null +++ b/pkg/http/testdata/.borg-rates.yaml @@ -0,0 +1,21 @@ +defaults: + requests_per_second: 1 + burst: 5 + +domains: + api.github.com: + requests_per_second: 0.5 # 1 req per 2 seconds + burst: 1 + + bitcointalk.org: + requests_per_second: 0.2 # 1 req per 5 seconds + burst: 1 + reason: "aggressive anti-scraping" + + eprint.iacr.org: + requests_per_second: 2 + burst: 10 + + "*.archive.org": + requests_per_second: 1 + burst: 3 diff --git a/pkg/vcs/git.go b/pkg/vcs/git.go index 92e20aa..b6e88d2 100644 --- a/pkg/vcs/git.go +++ b/pkg/vcs/git.go @@ -2,12 +2,15 @@ package vcs import ( "io" + "net/http" "os" "path/filepath" + "sync" "github.com/Snider/Borg/pkg/datanode" "github.com/go-git/go-git/v5" + githttp "github.com/go-git/go-git/v5/plumbing/transport/http" ) // GitCloner is an interface for cloning Git repositories. @@ -15,12 +18,26 @@ type GitCloner interface { CloneGitRepository(repoURL string, progress io.Writer) (*datanode.DataNode, error) } -// NewGitCloner creates a new GitCloner. +// NewGitCloner creates a new GitCloner with the default http client. func NewGitCloner() GitCloner { - return &gitCloner{} + return NewGitClonerWithClient(http.DefaultClient) } -type gitCloner struct{} +// NewGitClonerWithClient creates a new GitCloner with a custom http.Client. +func NewGitClonerWithClient(client *http.Client) GitCloner { + if client == nil { + client = http.DefaultClient + } + return &gitCloner{ + httpClient: client, + } +} + +type gitCloner struct { + httpClient *http.Client +} + +var cloneMutex = &sync.Mutex{} // CloneGitRepository clones a Git repository from a URL and packages it into a DataNode. func (g *gitCloner) CloneGitRepository(repoURL string, progress io.Writer) (*datanode.DataNode, error) { @@ -37,6 +54,14 @@ func (g *gitCloner) CloneGitRepository(repoURL string, progress io.Writer) (*dat cloneOptions.Progress = progress } + cloneMutex.Lock() + originalClient := githttp.DefaultClient + githttp.DefaultClient = githttp.NewClient(g.httpClient) + defer func() { + githttp.DefaultClient = originalClient + cloneMutex.Unlock() + }() + _, err = git.PlainClone(tempPath, false, cloneOptions) if err != nil { if err.Error() == "remote repository is empty" { diff --git a/pkg/website/website.go b/pkg/website/website.go index b2bd517..ffbc05c 100644 --- a/pkg/website/website.go +++ b/pkg/website/website.go @@ -43,13 +43,17 @@ func NewDownloaderWithClient(maxDepth int, client *http.Client) *Downloader { } // downloadAndPackageWebsite downloads a website and packages it into a DataNode. -func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) { +func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar, client *http.Client) (*datanode.DataNode, error) { baseURL, err := url.Parse(startURL) if err != nil { return nil, err } - d := NewDownloader(maxDepth) + if client == nil { + client = http.DefaultClient + } + + d := NewDownloaderWithClient(maxDepth, client) d.baseURL = baseURL d.progressBar = bar d.crawl(startURL, 0) diff --git a/pkg/website/website_test.go b/pkg/website/website_test.go index d3685e5..90c7ec9 100644 --- a/pkg/website/website_test.go +++ b/pkg/website/website_test.go @@ -20,7 +20,7 @@ func TestDownloadAndPackageWebsite_Good(t *testing.T) { defer server.Close() bar := progressbar.NewOptions(1, progressbar.OptionSetWriter(io.Discard)) - dn, err := DownloadAndPackageWebsite(server.URL, 2, bar) + dn, err := DownloadAndPackageWebsite(server.URL, 2, bar, nil) if err != nil { t.Fatalf("DownloadAndPackageWebsite failed: %v", err) } @@ -52,7 +52,7 @@ func TestDownloadAndPackageWebsite_Good(t *testing.T) { func TestDownloadAndPackageWebsite_Bad(t *testing.T) { t.Run("Invalid Start URL", func(t *testing.T) { - _, err := DownloadAndPackageWebsite("http://invalid-url", 1, nil) + _, err := DownloadAndPackageWebsite("http://invalid-url", 1, nil, nil) if err == nil { t.Fatal("Expected an error for an invalid start URL, but got nil") } @@ -63,7 +63,7 @@ func TestDownloadAndPackageWebsite_Bad(t *testing.T) { http.Error(w, "Internal Server Error", http.StatusInternalServerError) })) defer server.Close() - _, err := DownloadAndPackageWebsite(server.URL, 1, nil) + _, err := DownloadAndPackageWebsite(server.URL, 1, nil, nil) if err == nil { t.Fatal("Expected an error for a server error on the start URL, but got nil") } @@ -80,7 +80,7 @@ func TestDownloadAndPackageWebsite_Bad(t *testing.T) { })) defer server.Close() // We expect an error because the link is broken. - dn, err := DownloadAndPackageWebsite(server.URL, 1, nil) + dn, err := DownloadAndPackageWebsite(server.URL, 1, nil, nil) if err == nil { t.Fatal("Expected an error for a broken link, but got nil") } @@ -99,7 +99,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) { defer server.Close() bar := progressbar.NewOptions(1, progressbar.OptionSetWriter(io.Discard)) - dn, err := DownloadAndPackageWebsite(server.URL, 1, bar) // Max depth of 1 + dn, err := DownloadAndPackageWebsite(server.URL, 1, bar, nil) // Max depth of 1 if err != nil { t.Fatalf("DownloadAndPackageWebsite failed: %v", err) } @@ -122,7 +122,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) { fmt.Fprint(w, `External`) })) defer server.Close() - dn, err := DownloadAndPackageWebsite(server.URL, 1, nil) + dn, err := DownloadAndPackageWebsite(server.URL, 1, nil, nil) if err != nil { t.Fatalf("DownloadAndPackageWebsite failed: %v", err) } @@ -156,7 +156,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) { // For now, we'll just test that it doesn't hang forever. done := make(chan bool) go func() { - _, err := DownloadAndPackageWebsite(server.URL, 1, nil) + _, err := DownloadAndPackageWebsite(server.URL, 1, nil, nil) if err != nil && !strings.Contains(err.Error(), "context deadline exceeded") { // We expect a timeout error, but other errors are failures. t.Errorf("unexpected error: %v", err)