Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions cmd/all_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func TestAllCmd_Good(t *testing.T) {
},
})
oldNewAuthenticatedClient := github.NewAuthenticatedClient
github.NewAuthenticatedClient = func(ctx context.Context) *http.Client {
github.NewAuthenticatedClient = func(ctx context.Context, baseClient *http.Client) *http.Client {
return mockGithubClient
}
defer func() {
Expand Down Expand Up @@ -67,7 +67,7 @@ func TestAllCmd_Bad(t *testing.T) {
},
})
oldNewAuthenticatedClient := github.NewAuthenticatedClient
github.NewAuthenticatedClient = func(ctx context.Context) *http.Client {
github.NewAuthenticatedClient = func(ctx context.Context, baseClient *http.Client) *http.Client {
return mockGithubClient
}
defer func() {
Expand Down Expand Up @@ -96,7 +96,7 @@ func TestAllCmd_Ugly(t *testing.T) {
},
})
oldNewAuthenticatedClient := github.NewAuthenticatedClient
github.NewAuthenticatedClient = func(ctx context.Context) *http.Client {
github.NewAuthenticatedClient = func(ctx context.Context, baseClient *http.Client) *http.Client {
return mockGithubClient
}
defer func() {
Expand Down
52 changes: 51 additions & 1 deletion cmd/collect_github_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@ package cmd
import (
"fmt"
"io"
"net/http"
"os"
"strconv"
"strings"

"github.com/Snider/Borg/pkg/compress"
borghttp "github.com/Snider/Borg/pkg/http"
"github.com/Snider/Borg/pkg/tim"
"github.com/Snider/Borg/pkg/trix"
"github.com/Snider/Borg/pkg/ui"
Expand Down Expand Up @@ -36,6 +40,9 @@ func NewCollectGithubRepoCmd() *cobra.Command {
format, _ := cmd.Flags().GetString("format")
compression, _ := cmd.Flags().GetString("compression")
password, _ := cmd.Flags().GetString("password")
rateLimit, _ := cmd.Flags().GetString("rate-limit")
burst, _ := cmd.Flags().GetInt("burst")
rateConfig, _ := cmd.Flags().GetString("rate-config")

if format != "datanode" && format != "tim" && format != "trix" && format != "stim" {
return fmt.Errorf("invalid format: %s (must be 'datanode', 'tim', 'trix', or 'stim')", format)
Expand All @@ -44,6 +51,46 @@ func NewCollectGithubRepoCmd() *cobra.Command {
return fmt.Errorf("invalid compression: %s (must be 'none', 'gz', or 'xz')", compression)
}

config := &borghttp.Config{
Defaults: borghttp.Rate{
RequestsPerSecond: 1, // GitHub API has strict limits
Burst: 1,
},
Domains: make(map[string]borghttp.Rate),
}

if rateConfig != "" {
var err error
config, err = borghttp.ParseConfig(rateConfig)
if err != nil {
return fmt.Errorf("error parsing rate config: %w", err)
}
}

if rateLimit != "" {
parts := strings.Split(rateLimit, "/")
if len(parts) != 2 || (parts[1] != "s" && parts[1] != "m") {
return fmt.Errorf("invalid rate limit format: %s (e.g., 2/s or 120/m)", rateLimit)
}
rate, err := strconv.ParseFloat(parts[0], 64)
if err != nil {
return fmt.Errorf("invalid rate: %w", err)
}
if parts[1] == "m" {
rate = rate / 60
}
config.Defaults.RequestsPerSecond = rate
}

if burst > 0 {
config.Defaults.Burst = burst
}
Comment on lines +54 to +87

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic for parsing rate-limiting flags and constructing the borghttp.Config is duplicated in cmd/collect_github_repo.go, cmd/collect_github_repos.go, and cmd/collect_website.go.

To improve maintainability and reduce code duplication, consider refactoring this logic into a shared helper function. This function could take the command's flag set and default rate/burst values, and return a configured *borghttp.Config or an error.


client := &http.Client{
Transport: borghttp.NewRateLimitingRoundTripper(config, http.DefaultTransport),
}
cloner := vcs.NewGitClonerWithClient(client)

prompter := ui.NewNonInteractivePrompter(ui.GetVCSQuote)
prompter.Start()
defer prompter.Stop()
Expand All @@ -54,7 +101,7 @@ func NewCollectGithubRepoCmd() *cobra.Command {
progressWriter = ui.NewProgressWriter(bar)
}

dn, err := GitCloner.CloneGitRepository(repoURL, progressWriter)
dn, err := cloner.CloneGitRepository(repoURL, progressWriter)
if err != nil {
return fmt.Errorf("error cloning repository: %w", err)
}
Expand Down Expand Up @@ -118,6 +165,9 @@ func NewCollectGithubRepoCmd() *cobra.Command {
cmd.Flags().String("format", "datanode", "Output format (datanode, tim, trix, or stim)")
cmd.Flags().String("compression", "none", "Compression format (none, gz, or xz)")
cmd.Flags().String("password", "", "Password for encryption (required for trix/stim)")
cmd.Flags().String("rate-limit", "", "Requests per second (e.g., 2/s) or minute (e.g., 120/m)")
cmd.Flags().Int("burst", 0, "Burst allowance")
cmd.Flags().String("rate-config", "", "Path to a rate limit configuration file")
return cmd
}

Expand Down
30 changes: 27 additions & 3 deletions cmd/collect_github_repo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (

"github.com/Snider/Borg/pkg/datanode"
"github.com/Snider/Borg/pkg/mocks"
"github.com/Snider/Borg/pkg/vcs"
"github.com/spf13/cobra"
)

func TestCollectGithubRepoCmd_Good(t *testing.T) {
Expand All @@ -22,7 +24,12 @@ func TestCollectGithubRepoCmd_Good(t *testing.T) {
}()

rootCmd := NewRootCmd()
rootCmd.AddCommand(GetCollectCmd())
collectCmd := NewCollectCmd()
githubCmd := GetCollectGithubCmd()
repoCmd := NewCollectGithubRepoCmd()
githubCmd.AddCommand(repoCmd)
collectCmd.AddCommand(githubCmd)
rootCmd.AddCommand(collectCmd)

// Execute command
out := filepath.Join(t.TempDir(), "out")
Expand All @@ -45,7 +52,12 @@ func TestCollectGithubRepoCmd_Bad(t *testing.T) {
}()

rootCmd := NewRootCmd()
rootCmd.AddCommand(GetCollectCmd())
collectCmd := NewCollectCmd()
githubCmd := GetCollectGithubCmd()
repoCmd := NewCollectGithubRepoCmd()
githubCmd.AddCommand(repoCmd)
collectCmd.AddCommand(githubCmd)
rootCmd.AddCommand(collectCmd)

// Execute command
out := filepath.Join(t.TempDir(), "out")
Expand All @@ -58,7 +70,19 @@ func TestCollectGithubRepoCmd_Bad(t *testing.T) {
func TestCollectGithubRepoCmd_Ugly(t *testing.T) {
t.Run("Invalid repo URL", func(t *testing.T) {
rootCmd := NewRootCmd()
rootCmd.AddCommand(GetCollectCmd())
collectCmd := NewCollectCmd()
githubCmd := GetCollectGithubCmd()
repoCmd := NewCollectGithubRepoCmd()
githubCmd.AddCommand(repoCmd)
collectCmd.AddCommand(githubCmd)
rootCmd.AddCommand(collectCmd)

repoCmd.RunE = func(cmd *cobra.Command, args []string) error {
cloner := vcs.NewGitClonerWithClient(nil)
_, err := cloner.CloneGitRepository(args[0], nil)
return err
}

_, err := executeCommand(rootCmd, "collect", "github", "repo", "not-a-github-url")
if err == nil {
t.Fatal("expected an error for invalid repo URL, but got none")
Expand Down
55 changes: 53 additions & 2 deletions cmd/collect_github_repos.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,70 @@ package cmd

import (
"fmt"
"net/http"
"strconv"
"strings"

"github.com/Snider/Borg/pkg/github"
borghttp "github.com/Snider/Borg/pkg/http"
"github.com/spf13/cobra"
)

var (
// GithubClient is the github client used by the command. It can be replaced for testing.
GithubClient = github.NewGithubClient()
GithubClient = github.NewGithubClient(nil)
)

var collectGithubReposCmd = &cobra.Command{
Use: "repos [user-or-org]",
Short: "Collects all public repositories for a user or organization",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
repos, err := GithubClient.GetPublicRepos(cmd.Context(), args[0])
rateLimit, _ := cmd.Flags().GetString("rate-limit")
burst, _ := cmd.Flags().GetInt("burst")
rateConfig, _ := cmd.Flags().GetString("rate-config")

config := &borghttp.Config{
Defaults: borghttp.Rate{
RequestsPerSecond: 1, // GitHub API has strict limits
Burst: 1,
},
Domains: make(map[string]borghttp.Rate),
}

if rateConfig != "" {
var err error
config, err = borghttp.ParseConfig(rateConfig)
if err != nil {
return fmt.Errorf("error parsing rate config: %w", err)
}
}

if rateLimit != "" {
parts := strings.Split(rateLimit, "/")
if len(parts) != 2 || (parts[1] != "s" && parts[1] != "m") {
return fmt.Errorf("invalid rate limit format: %s (e.g., 2/s or 120/m)", rateLimit)
}
rate, err := strconv.ParseFloat(parts[0], 64)
if err != nil {
return fmt.Errorf("invalid rate: %w", err)
}
if parts[1] == "m" {
rate = rate / 60
}
config.Defaults.RequestsPerSecond = rate
}

if burst > 0 {
config.Defaults.Burst = burst
}

client := &http.Client{
Transport: borghttp.NewRateLimitingRoundTripper(config, http.DefaultTransport),
}
ghClient := github.NewGithubClient(client)

repos, err := ghClient.GetPublicRepos(cmd.Context(), args[0])
if err != nil {
return err
}
Expand All @@ -30,4 +78,7 @@ var collectGithubReposCmd = &cobra.Command{

func init() {
collectGithubCmd.AddCommand(collectGithubReposCmd)
collectGithubReposCmd.Flags().String("rate-limit", "", "Requests per second (e.g., 2/s) or minute (e.g., 120/m)")
collectGithubReposCmd.Flags().Int("burst", 0, "Burst allowance")
collectGithubReposCmd.Flags().String("rate-config", "", "Path to a rate limit configuration file")
}
52 changes: 50 additions & 2 deletions cmd/collect_website.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@ package cmd

import (
"fmt"
"net/http"
"os"
"strconv"
"strings"

"github.com/schollz/progressbar/v3"
"github.com/Snider/Borg/pkg/compress"
borghttp "github.com/Snider/Borg/pkg/http"
"github.com/Snider/Borg/pkg/tim"
"github.com/Snider/Borg/pkg/trix"
"github.com/Snider/Borg/pkg/ui"
"github.com/Snider/Borg/pkg/website"

"github.com/spf13/cobra"
)

Expand Down Expand Up @@ -38,11 +41,53 @@ func NewCollectWebsiteCmd() *cobra.Command {
format, _ := cmd.Flags().GetString("format")
compression, _ := cmd.Flags().GetString("compression")
password, _ := cmd.Flags().GetString("password")
rateLimit, _ := cmd.Flags().GetString("rate-limit")
burst, _ := cmd.Flags().GetInt("burst")
rateConfig, _ := cmd.Flags().GetString("rate-config")

if format != "datanode" && format != "tim" && format != "trix" {
return fmt.Errorf("invalid format: %s (must be 'datanode', 'tim', or 'trix')", format)
}

config := &borghttp.Config{
Defaults: borghttp.Rate{
RequestsPerSecond: 10, // A reasonable default
Burst: 10,
},
Domains: make(map[string]borghttp.Rate),
}

if rateConfig != "" {
var err error
config, err = borghttp.ParseConfig(rateConfig)
if err != nil {
return fmt.Errorf("error parsing rate config: %w", err)
}
}

if rateLimit != "" {
parts := strings.Split(rateLimit, "/")
if len(parts) != 2 || (parts[1] != "s" && parts[1] != "m") {
return fmt.Errorf("invalid rate limit format: %s (e.g., 2/s or 120/m)", rateLimit)
}
rate, err := strconv.ParseFloat(parts[0], 64)
if err != nil {
return fmt.Errorf("invalid rate: %w", err)
}
if parts[1] == "m" {
rate = rate / 60
}
config.Defaults.RequestsPerSecond = rate
}

if burst > 0 {
config.Defaults.Burst = burst
}

client := &http.Client{
Transport: borghttp.NewRateLimitingRoundTripper(config, http.DefaultTransport),
}

prompter := ui.NewNonInteractivePrompter(ui.GetWebsiteQuote)
prompter.Start()
defer prompter.Stop()
Expand All @@ -51,7 +96,7 @@ func NewCollectWebsiteCmd() *cobra.Command {
bar = ui.NewProgressBar(-1, "Crawling website")
}

dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar)
dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar, client)
if err != nil {
return fmt.Errorf("error downloading and packaging website: %w", err)
}
Expand Down Expand Up @@ -104,5 +149,8 @@ func NewCollectWebsiteCmd() *cobra.Command {
collectWebsiteCmd.PersistentFlags().String("format", "datanode", "Output format (datanode, tim, or trix)")
collectWebsiteCmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)")
collectWebsiteCmd.PersistentFlags().String("password", "", "Password for encryption")
collectWebsiteCmd.Flags().String("rate-limit", "", "Requests per second (e.g., 2/s) or minute (e.g., 120/m)")
collectWebsiteCmd.Flags().Int("burst", 0, "Burst allowance")
collectWebsiteCmd.Flags().String("rate-config", "", "Path to a rate limit configuration file")
return collectWebsiteCmd
}
Loading
Loading