From a0aaf0d215290c29800657b0cc7dd6f2338a9f72 Mon Sep 17 00:00:00 2001 From: Zhang JL Date: Tue, 21 Jan 2025 22:06:44 +0800 Subject: [PATCH 1/8] feat(github_release): add GitHub Release driver --- drivers/all.go | 1 + drivers/github_release/backoff.go | 45 +++ drivers/github_release/backoff_test.go | 37 ++ drivers/github_release/driver.go | 168 +++++++++ drivers/github_release/github.go | 218 +++++++++++ drivers/github_release/meta.go | 34 ++ drivers/github_release/types.go | 253 +++++++++++++ drivers/github_release/types_test.go | 477 +++++++++++++++++++++++++ drivers/github_release/util.go | 1 + 9 files changed, 1234 insertions(+) create mode 100644 drivers/github_release/backoff.go create mode 100644 drivers/github_release/backoff_test.go create mode 100644 drivers/github_release/driver.go create mode 100644 drivers/github_release/github.go create mode 100644 drivers/github_release/meta.go create mode 100644 drivers/github_release/types.go create mode 100644 drivers/github_release/types_test.go create mode 100644 drivers/github_release/util.go diff --git a/drivers/all.go b/drivers/all.go index 8b253a08558..1fd4fdd0a73 100644 --- a/drivers/all.go +++ b/drivers/all.go @@ -25,6 +25,7 @@ import ( _ "github.com/alist-org/alist/v3/drivers/febbox" _ "github.com/alist-org/alist/v3/drivers/ftp" _ "github.com/alist-org/alist/v3/drivers/github" + _ "github.com/alist-org/alist/v3/drivers/github_release" _ "github.com/alist-org/alist/v3/drivers/google_drive" _ "github.com/alist-org/alist/v3/drivers/google_photo" _ "github.com/alist-org/alist/v3/drivers/halalcloud" diff --git a/drivers/github_release/backoff.go b/drivers/github_release/backoff.go new file mode 100644 index 00000000000..f6edb2c1cae --- /dev/null +++ b/drivers/github_release/backoff.go @@ -0,0 +1,45 @@ +package template + +import ( + "math/rand" + "time" +) + +const ( + initialRetryInterval = 500 * time.Millisecond + maxInterval = 1 * time.Minute + maxElapsedTime = 15 * time.Minute + randomizationFactor = 0.5 + multiplier = 1.5 +) + +// Backoff 提供了确定在重试操作之前等待的时间算法 +type Backoff struct { + interval time.Duration + elapsedTime time.Duration +} + +// Pause 返回重试操作之前等待的时间量,如果可以再次尝试则返回 true,否则返回 false,表示操作应该被放弃。 +func (b *Backoff) Pause() (time.Duration, bool) { + if b.interval == 0 { + // first time + b.interval = initialRetryInterval + b.elapsedTime = 0 + } + + // interval from [1 - randomizationFactor, 1 + randomizationFactor) + randomizedInterval := time.Duration((rand.Float64()*(2*randomizationFactor) + (1 - randomizationFactor)) * float64(b.interval)) + b.elapsedTime += randomizedInterval + + if b.elapsedTime > maxElapsedTime { + return 0, false + } + + // 将间隔增加到间隔上限 + b.interval = time.Duration(float64(b.interval) * multiplier) + if b.interval > maxInterval { + b.interval = maxInterval + } + + return randomizedInterval, true +} diff --git a/drivers/github_release/backoff_test.go b/drivers/github_release/backoff_test.go new file mode 100644 index 00000000000..f6f088ef6b3 --- /dev/null +++ b/drivers/github_release/backoff_test.go @@ -0,0 +1,37 @@ +package template + +import ( + "testing" + "time" +) + +func TestBackoffMultiple(t *testing.T) { + b := &Backoff{} + for i := 0; i < 19; i++ { + p, ok := b.Pause() + t.Logf("iteration %d pausing for %s", i, p) + if !ok { + t.Fatalf("hit the pause timeout after %d pauses", i) + } + } +} + +func TestBackoffTimeout(t *testing.T) { + var elapsed time.Duration + b := &Backoff{} + for i := 0; i < 40; i++ { + p, ok := b.Pause() + elapsed += p + t.Logf("iteration %d pausing for %s (total %s)", i, p, elapsed) + if !ok { + break + } + } + if _, ok := b.Pause(); ok { + t.Fatalf("did not hit the pause timeout") + } + + if elapsed > maxElapsedTime { + t.Fatalf("waited too long: %s > %s", elapsed, maxElapsedTime) + } +} diff --git a/drivers/github_release/driver.go b/drivers/github_release/driver.go new file mode 100644 index 00000000000..bc7427edb20 --- /dev/null +++ b/drivers/github_release/driver.go @@ -0,0 +1,168 @@ +package template + +import ( + "context" + "fmt" + "net/http" + "strconv" + "strings" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/pkg/errors" +) + +type GithubRelease struct { + model.Storage + Addition + + api *ApiContext + repo repository +} + +func (d *GithubRelease) Config() driver.Config { + return config +} + +func (d *GithubRelease) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *GithubRelease) Init(ctx context.Context) error { + token := d.Addition.Token + if token == "" { + return errs.EmptyToken + } + + if d.Addition.MaxReleases < 1 { + return errors.New("max_releases must be greater than 0") + } + + if d.Addition.MaxReleases > 100 { + d.Addition.MaxReleases = 100 + } + + d.api = NewApiContext(token, nil) + + repo, err := newRepository(d.Addition.Repo) + if err != nil { + return err + } + d.repo = repo + + return nil +} + +// Drop Delete this driver +func (d *GithubRelease) Drop(ctx context.Context) error { + return nil +} + +func (d *GithubRelease) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + repo, err := newRepository(d.Addition.Repo) + if err != nil { + return nil, err + } + + // 判断 dir 是不是挂在点。如果 dir 是挂载点,则返回所有的 release; + // 如果 dir 不是挂载点,则返回 dir 下的 release。 + if dir.GetPath() == "" { + releases, err := d.api.GetReleases(repo, d.Addition.MaxReleases) + if err != nil { + return nil, err + } + return releases, nil + } + + idStr := dir.GetID() + id, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + return nil, errors.Wrapf(err, "list release %s failed, id is not a number", idStr) + } + + release, err := d.api.GetRelease(repo, id) + if err != nil { + return nil, err + } + + return release.Children() +} + +func (d *GithubRelease) proxyDownload(file model.Obj, args model.LinkArgs) bool { + if d.Config().MustProxy() || d.GetStorage().WebProxy { + return true + } + + req := args.HttpReq + if args.HttpReq != nil && + req.URL != nil && + strings.HasPrefix(req.URL.Path, fmt.Sprintf("/p%s", d.GetStorage().MountPath)) { + return true + } + + return false +} + +func (d *GithubRelease) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + idStr := file.GetID() + id, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + return nil, errors.Wrapf(err, "get link of file %s failed, id is not a number", idStr) + } + asset, err := d.api.GetReleaseAsset(d.repo, id) + if err != nil { + return nil, err + } + + if d.proxyDownload(file, args) { + + header := http.Header{ + "User-Agent": {"Alist/" + conf.VERSION}, + "Accept": {"application/octet-stream"}, + } + d.api.SetAuthHeader(header) + + return &model.Link{ + URL: asset.URL, + Header: header, + }, nil + } + + return &model.Link{ + URL: asset.BrowserDownloadURL, + }, nil + +} + +func (d *GithubRelease) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) (model.Obj, error) { + return nil, errs.NotSupport +} + +func (d *GithubRelease) Move(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) { + return nil, errs.NotSupport +} + +func (d *GithubRelease) Rename(ctx context.Context, srcObj model.Obj, newName string) (model.Obj, error) { + // TODO rename obj, optional + return nil, errs.NotImplement +} + +func (d *GithubRelease) Copy(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) { + return nil, errs.NotSupport +} + +func (d *GithubRelease) Remove(ctx context.Context, obj model.Obj) error { + return errs.NotSupport +} + +func (d *GithubRelease) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { + return nil, errs.NotSupport +} + +//func (d *Template) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) { +// return nil, errs.NotSupport +//} + +var _ driver.Driver = (*GithubRelease)(nil) diff --git a/drivers/github_release/github.go b/drivers/github_release/github.go new file mode 100644 index 00000000000..9b532d0c6dd --- /dev/null +++ b/drivers/github_release/github.go @@ -0,0 +1,218 @@ +package template + +import ( + "fmt" + "io" + "net/http" + "time" + + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/pkg/errors" +) + +const GITHUB_API_VERSION = "2022-11-28" + +type ApiContext struct { + token string + version string + client *http.Client +} + +func NewApiContext(token string, client *http.Client) *ApiContext { + ret := ApiContext{ + token: token, + version: GITHUB_API_VERSION, + client: client, + } + + if ret.client == nil { + ret.client = http.DefaultClient + } + + return &ret +} + +// parseHTTPError 解析 HTTP 错误. +func parseHTTPError(body []byte) error { + var v map[string]interface{} + err := utils.Json.Unmarshal(body, &v) + if err != nil { + return errors.New(string(body)) + } + + iface, ok := v["message"] + if !ok { + return errors.New(string(body)) + } + + message, ok := iface.(string) + if !ok { + return errors.New(string(body)) + } + + return errors.New(message) +} + +// getWithRetry 获取 GitHub API 并重试. +func (a *ApiContext) getWithRetry(url string) (*http.Response, error) { + backoff := Backoff{} + + for { + response, err := a.get(url) + + // non-2xx code does not cause error + if err != nil { + // retry when error is not nil + p, retryAgain := backoff.Pause() + if !retryAgain { + return nil, errors.Wrap(err, "request failed") + } + utils.Log.Debugf("query github api error: %s, retry after %s", err, p) + time.Sleep(p) + continue + } + + // defensive check + if response == nil { + utils.Log.Errorf("query github api error: %s, will not retry", err) + return nil, errors.New("request failed: response is nil") + } + + if response.StatusCode == http.StatusOK { + return response, nil + } + + // We won't return the response to the caller here, but it's still better to read the response.Body completely even if we don't use it. + // see https://pkg.go.dev/net/http#Client.Do + body, err := io.ReadAll(response.Body) + if err != nil { + return nil, errors.Wrap(err, "failed to read response body") + } + + if response.StatusCode >= 500 && response.StatusCode <= 599 { + // retry when server error + p, retryAgain := backoff.Pause() + if !retryAgain { + return nil, parseHTTPError(body) + } + utils.Log.Debugf("query github api error: status code %d, retry after %s", response.StatusCode, p) + time.Sleep(p) + continue + } + + return nil, parseHTTPError(body) + } +} + +// SetAuthHeader 为请求头添加 GitHub API 所需的认证头. +// 这是一个副作用函数, 会直接修改传入的 header. +func (a *ApiContext) SetAuthHeader(header http.Header) { + header.Set("Authorization", fmt.Sprintf("Bearer %s", a.token)) +} + +// get 获取 GitHub API. +func (a *ApiContext) get(url string) (*http.Response, error) { + request, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + + request.Header.Set("Accept", "application/vnd.github+json") + a.SetAuthHeader(request.Header) + + response, err := a.client.Do(request) + if err != nil { + return nil, err + } + + return response, nil +} + +// GetReleases 获取仓库信息. +func (a *ApiContext) GetReleases(repo repository, perPage int) ([]model.Obj, error) { + if perPage < 1 { + perPage = 30 + } + url := fmt.Sprintf("https://api.github.com/repos/%s/releases?per_page=%d", repo.UrlEncode(), perPage) + response, err := a.getWithRetry(url) + if err != nil { + return nil, err + } + defer response.Body.Close() + + body, err := io.ReadAll(response.Body) + if err != nil { + return nil, errors.Wrap(err, "failed to read response body") + } + + if response.StatusCode != http.StatusOK { + return nil, parseHTTPError(body) + } + + releases := []Release{} + err = utils.Json.Unmarshal(body, &releases) + if err != nil { + return nil, errors.Wrap(err, "failed to unmarshal releases") + } + + tree := make([]model.Obj, 0, len(releases)) + for _, release := range releases { + tree = append(tree, &release) + } + return tree, nil +} + +// GetRelease 获取指定 tag 的 release. +func (a *ApiContext) GetRelease(repo repository, id int64) (*Release, error) { + url := fmt.Sprintf("https://api.github.com/repos/%s/releases/%d", repo.UrlEncode(), id) + response, err := a.getWithRetry(url) + if err != nil { + return nil, err + } + defer response.Body.Close() + + body, err := io.ReadAll(response.Body) + if err != nil { + return nil, errors.Wrap(err, "failed to read response body") + } + + if response.StatusCode != http.StatusOK { + return nil, parseHTTPError(body) + } + + release := Release{} + err = utils.Json.Unmarshal(body, &release) + if err != nil { + return nil, errors.Wrap(err, "failed to unmarshal release") + } + + return &release, nil +} + +// GetReleaseAsset 获取指定 tag 的 release 的 assets. +func (a *ApiContext) GetReleaseAsset(repo repository, ID int64) (*Asset, error) { + url := fmt.Sprintf("https://api.github.com/repos/%s/releases/assets/%d", repo.UrlEncode(), ID) + response, err := a.getWithRetry(url) + if err != nil { + return nil, err + } + defer response.Body.Close() + + body, err := io.ReadAll(response.Body) + if err != nil { + return nil, errors.Wrap(err, "failed to read response body") + } + + if response.StatusCode != http.StatusOK { + return nil, parseHTTPError(body) + } + + asset := Asset{} + err = utils.Json.Unmarshal(body, &asset) + if err != nil { + return nil, errors.Wrap(err, "failed to unmarshal asset") + } + + return &asset, nil +} diff --git a/drivers/github_release/meta.go b/drivers/github_release/meta.go new file mode 100644 index 00000000000..f925972391c --- /dev/null +++ b/drivers/github_release/meta.go @@ -0,0 +1,34 @@ +package template + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + driver.RootID + // define other + Repo string `json:"repo" required:"true" default:"AlistGo/alist"` + Token string `json:"token" required:"true" default:""` + MaxReleases int `json:"max_releases" required:"true" type:"number" default:"30" help:"max releases to list"` +} + +var config = driver.Config{ + Name: "Github Release", + LocalSort: false, + OnlyLocal: false, + OnlyProxy: false, + NoCache: false, + NoUpload: true, + NeedMs: false, + DefaultRoot: "0", + CheckStatus: false, + Alert: "", + NoOverwriteUpload: false, +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &GithubRelease{} + }) +} diff --git a/drivers/github_release/types.go b/drivers/github_release/types.go new file mode 100644 index 00000000000..ca6205ec114 --- /dev/null +++ b/drivers/github_release/types.go @@ -0,0 +1,253 @@ +package template + +import ( + "fmt" + "net/url" + "regexp" + "strings" + "time" + + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/pkg/errors" +) + +type repository struct { + owner string + name string +} + +func newRepository(name string) (repository, error) { + parts := strings.Split(name, "/") + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return repository{}, errors.New("repo name must be in the format of owner/repo") + } + + return repository{ + owner: parts[0], + name: parts[1], + }, nil +} + +func (r *repository) String() string { + return fmt.Sprintf("%s/%s", r.owner, r.name) +} + +func (r *repository) UrlEncode() string { + ownerPart := url.QueryEscape(r.owner) + namePart := url.QueryEscape(r.name) + return fmt.Sprintf("%s/%s", ownerPart, namePart) +} + +type Release struct { + URL string `json:"url"` + HTMLURL string `json:"html_url"` + AssetsURL string `json:"assets_url"` + UploadURL string `json:"upload_url"` + TarballURL string `json:"tarball_url"` + ZipballURL string `json:"zipball_url"` + ID int64 `json:"id"` + NodeID string `json:"node_id"` + TagName string `json:"tag_name"` + TargetCommitish string `json:"target_commitish"` + Name string `json:"name"` + Body string `json:"body"` + Draft bool `json:"draft"` + Prerelease bool `json:"prerelease"` + CreatedAt time.Time `json:"created_at"` + PublishedAt time.Time `json:"published_at"` + Author User `json:"author"` + Assets []Asset `json:"assets"` + BodyHTML string `json:"body_html"` + BodyText string `json:"body_text"` + MentionsCount int `json:"mentions_count"` + DiscussionURL string `json:"discussion_url"` +} + +func (r *Release) UnmarshalJSON(data []byte) error { + type alias Release + aux := struct { + CreatedAt string `json:"created_at"` + PublishedAt string `json:"published_at"` + *alias + }{ + alias: (*alias)(r), + } + + if err := utils.Json.Unmarshal(data, &aux); err != nil { + return errors.Wrap(err, "failed to unmarshal release") + } + + createdAt, err := time.Parse(time.RFC3339, aux.CreatedAt) + if err != nil { + utils.Log.Error("failed to parse created_at in release", "error", err) + createdAt = time.Time{} + } else { + r.CreatedAt = createdAt + } + + publishedAt, err := time.Parse(time.RFC3339, aux.PublishedAt) + if err != nil { + utils.Log.Error("failed to parse published_at in release", "error", err) + publishedAt = time.Time{} + } else { + r.PublishedAt = publishedAt + } + + return nil +} + +func (r *Release) GetSize() int64 { + return 0 +} + +func (r *Release) GetName() string { + return r.TagName +} + +func (r *Release) ModTime() time.Time { + return r.PublishedAt +} + +func (r *Release) CreateTime() time.Time { + return r.CreatedAt +} + +func (r *Release) IsDir() bool { + return true +} + +func (r *Release) GetHash() utils.HashInfo { + return utils.HashInfo{} +} + +func (r *Release) GetID() string { + return fmt.Sprintf("%d", r.ID) +} + +func (r *Release) GetPath() string { + return r.TagName +} + +func (r *Release) Children() ([]model.Obj, error) { + return utils.SliceConvert(r.Assets, func(src Asset) (model.Obj, error) { + return &src, nil + }) +} + +type Asset struct { + URL string `json:"url"` + BrowserDownloadURL string `json:"browser_download_url"` + ID int64 `json:"id"` + NodeID string `json:"node_id"` + Name string `json:"name"` + Label string `json:"label"` + State string `json:"state"` + ContentType string `json:"content_type"` + Size int64 `json:"size"` + DownloadCount int64 `json:"download_count"` + CreatedAt *time.Time `json:"created_at"` + UpdatedAt *time.Time `json:"updated_at"` + Uploader *User `json:"uploader"` +} + +func (a *Asset) UnmarshalJSON(data []byte) error { + type alias Asset + aux := struct { + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` + *alias + }{ + alias: (*alias)(a), + } + + if err := utils.Json.Unmarshal(data, &aux); err != nil { + return errors.Wrap(err, "failed to unmarshal asset") + } + + createdAt, err := time.Parse(time.RFC3339, aux.CreatedAt) + if err != nil { + return errors.Wrap(err, "failed to parse created_at in asset") + } + + a.CreatedAt = &createdAt + + updatedAt, err := time.Parse(time.RFC3339, aux.UpdatedAt) + if err != nil { + return errors.Wrap(err, "failed to parse updated_at in asset") + } + + a.UpdatedAt = &updatedAt + + return nil +} +func (a *Asset) GetSize() (_ int64) { + return a.Size +} + +func (a *Asset) GetName() (_ string) { + return a.Name +} + +func (a *Asset) ModTime() (_ time.Time) { + if a.UpdatedAt == nil { + return time.Time{} + } + return *a.UpdatedAt +} + +func (a *Asset) CreateTime() (_ time.Time) { + if a.CreatedAt == nil { + return time.Time{} + } + return *a.CreatedAt +} + +func (a *Asset) IsDir() bool { + return false +} + +// GetHash 获取文件的哈希值. github release api 不提供哈希值 +func (a *Asset) GetHash() utils.HashInfo { + return utils.HashInfo{} +} + +func (a *Asset) GetID() string { + return fmt.Sprintf("%d", a.ID) +} + +// GetPath 获取路径. 通过解析 Asset.BrowserDownloadURL 获取 +func (a *Asset) GetPath() string { + pattern := `https://github.com/[^/]+/[^/]+/releases/download/([^/]+)/([^/]+)` + re := regexp.MustCompile(pattern) + matches := re.FindStringSubmatch(a.BrowserDownloadURL) + if len(matches) != 3 { + return "" + } + tag := matches[1] + assetName := matches[2] + return fmt.Sprintf("%s/%s", tag, assetName) +} + +type User struct { + Name string `json:"name"` + Email string `json:"email"` + Login string `json:"login"` + ID int64 `json:"id"` + NodeID string `json:"node_id"` + AvatarURL string `json:"avatar_url"` + GravatarID string `json:"gravatar_id"` + URL string `json:"url"` + HTMLURL string `json:"html_url"` + FollowersURL string `json:"followers_url"` + FollowingURL string `json:"following_url"` + GistsURL string `json:"gists_url"` + StarredURL string `json:"starred_url"` + SubscriptionsURL string `json:"subscriptions_url"` + OrganizationsURL string `json:"organizations_url"` + ReposURL string `json:"repos_url"` + EventsURL string `json:"events_url"` + ReceivedEventsURL string `json:"received_events_url"` + Type string `json:"type"` + SiteAdmin bool `json:"site_admin"` +} diff --git a/drivers/github_release/types_test.go b/drivers/github_release/types_test.go new file mode 100644 index 00000000000..dd24fbe767b --- /dev/null +++ b/drivers/github_release/types_test.go @@ -0,0 +1,477 @@ +package template + +import ( + "testing" + "time" + + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/stretchr/testify/assert" +) + +func TestNewRepository(t *testing.T) { + tests := []struct { + name string + input string + want repository + wantErr bool + }{ + { + name: "正常的仓库名称", + input: "alist-org/alist", + want: repository{ + owner: "alist-org", + name: "alist", + }, + wantErr: false, + }, + { + name: "缺少斜杠的仓库名称", + input: "alist-org", + want: repository{}, + wantErr: true, + }, + { + name: "空的所有者", + input: "/alist", + want: repository{}, + wantErr: true, + }, + { + name: "空的仓库名", + input: "alist-org/", + want: repository{}, + wantErr: true, + }, + { + name: "完全空的输入", + input: "", + want: repository{}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := newRepository(tt.input) + if tt.wantErr { + assert.Error(t, err) + assert.Equal(t, repository{}, got) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} + +func TestRepository_String(t *testing.T) { + repo := repository{ + owner: "alist-org", + name: "alist", + } + assert.Equal(t, "alist-org/alist", repo.String()) +} + +func TestRepository_UrlEncode(t *testing.T) { + tests := []struct { + name string + repo repository + want string + }{ + { + name: "普通仓库名称", + repo: repository{ + owner: "alist-org", + name: "alist", + }, + want: "alist-org/alist", + }, + { + name: "包含特殊字符的仓库名称", + repo: repository{ + owner: "user name", + name: "repo name", + }, + want: "user+name/repo+name", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.repo.UrlEncode() + assert.Equal(t, tt.want, got) + }) + } +} + +func TestRelease_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + json string + want *Release + invalidDatetime bool + }{ + { + name: "正常的发布数据", + json: `{ + "url": "https://api.github.com/repos/alist-org/alist/releases/1", + "html_url": "https://github.com/alist-org/alist/releases/tag/v1.0.0", + "tag_name": "v1.0.0", + "name": "Release v1.0.0", + "body": "Release notes", + "created_at": "2023-01-01T12:00:00Z", + "published_at": "2023-01-01T12:30:00Z", + "author": { + "login": "test-user", + "id": 1 + } + }`, + want: &Release{ + URL: "https://api.github.com/repos/alist-org/alist/releases/1", + HTMLURL: "https://github.com/alist-org/alist/releases/tag/v1.0.0", + TagName: "v1.0.0", + Name: "Release v1.0.0", + Body: "Release notes", + Author: User{ + Login: "test-user", + ID: 1, + }, + }, + invalidDatetime: false, + }, + { + name: "无效的时间格式", + json: `{ + "created_at": "invalid-time", + "published_at": "invalid-time" + }`, + want: nil, + invalidDatetime: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var release Release + err := release.UnmarshalJSON([]byte(tt.json)) + if tt.invalidDatetime { + assert.True(t, release.CreatedAt.IsZero()) + assert.True(t, release.PublishedAt.IsZero()) + } else { + assert.NoError(t, err) + // 验证时间字段 + assert.Equal(t, 2023, release.CreatedAt.Year()) + assert.Equal(t, 2023, release.PublishedAt.Year()) + // 验证其他字段 + assert.Equal(t, tt.want.URL, release.URL) + assert.Equal(t, tt.want.HTMLURL, release.HTMLURL) + assert.Equal(t, tt.want.TagName, release.TagName) + assert.Equal(t, tt.want.Name, release.Name) + assert.Equal(t, tt.want.Body, release.Body) + assert.Equal(t, tt.want.Author.Login, release.Author.Login) + assert.Equal(t, tt.want.Author.ID, release.Author.ID) + } + }) + } +} + +func TestAsset_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + json string + want *Asset + wantErr bool + }{ + { + name: "正常的资源数据", + json: `{ + "url": "https://api.github.com/repos/alist-org/alist/releases/assets/1", + "browser_download_url": "https://github.com/alist-org/alist/releases/download/v1.0.0/asset.zip", + "id": 1, + "name": "asset.zip", + "label": "Binary", + "state": "uploaded", + "content_type": "application/zip", + "size": 1024, + "download_count": 100, + "created_at": "2023-01-01T12:00:00Z", + "updated_at": "2023-01-01T12:30:00Z", + "uploader": { + "login": "test-user", + "id": 1 + } + }`, + want: &Asset{ + URL: "https://api.github.com/repos/alist-org/alist/releases/assets/1", + BrowserDownloadURL: "https://github.com/alist-org/alist/releases/download/v1.0.0/asset.zip", + ID: 1, + Name: "asset.zip", + Label: "Binary", + State: "uploaded", + ContentType: "application/zip", + Size: 1024, + DownloadCount: 100, + Uploader: &User{ + Login: "test-user", + ID: 1, + }, + }, + wantErr: false, + }, + { + name: "无效的时间格式", + json: `{ + "created_at": "invalid-time", + "updated_at": "2023-01-01T12:30:00Z" + }`, + want: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var asset Asset + err := asset.UnmarshalJSON([]byte(tt.json)) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + // 验证时间字段 + assert.Equal(t, 2023, asset.CreatedAt.Year()) + assert.Equal(t, 2023, asset.UpdatedAt.Year()) + // 验证其他字段 + assert.Equal(t, tt.want.URL, asset.URL) + assert.Equal(t, tt.want.BrowserDownloadURL, asset.BrowserDownloadURL) + assert.Equal(t, tt.want.ID, asset.ID) + assert.Equal(t, tt.want.Name, asset.Name) + assert.Equal(t, tt.want.Label, asset.Label) + assert.Equal(t, tt.want.State, asset.State) + assert.Equal(t, tt.want.ContentType, asset.ContentType) + assert.Equal(t, tt.want.Size, asset.Size) + assert.Equal(t, tt.want.DownloadCount, asset.DownloadCount) + assert.Equal(t, tt.want.Uploader.Login, asset.Uploader.Login) + assert.Equal(t, tt.want.Uploader.ID, asset.Uploader.ID) + } + }) + } +} + +func TestReleases_UnmarshalJSON(t *testing.T) { + jsonData := `[ + { + "url": "https://api.github.com/repos/AlistGo/alist/releases/170718825", + "assets_url": "https://api.github.com/repos/AlistGo/alist/releases/170718825/assets", + "upload_url": "https://uploads.github.com/repos/AlistGo/alist/releases/170718825/assets{?name,label}", + "html_url": "https://github.com/AlistGo/alist/releases/tag/beta", + "id": 170718825, + "author": { + "login": "xhofe", + "id": 36558727, + "node_id": "MDQ6VXNlcjM2NTU4NzI3", + "avatar_url": "https://avatars.githubusercontent.com/u/36558727?v=4", + "url": "https://api.github.com/users/xhofe", + "html_url": "https://github.com/xhofe", + "type": "User", + "site_admin": false + }, + "node_id": "RE_kwDOE09S284KLPZp", + "tag_name": "beta", + "target_commitish": "main", + "name": "AList Beta Version", + "draft": false, + "prerelease": true, + "created_at": "2025-01-18T15:52:02Z", + "published_at": "2024-08-17T14:10:08Z", + "assets": [ + { + "url": "https://api.github.com/repos/AlistGo/alist/releases/assets/221414212", + "id": 221414212, + "name": "alist-android-386.tar.gz", + "content_type": "application/gzip", + "state": "uploaded", + "size": 31186443, + "download_count": 6, + "created_at": "2025-01-18T15:58:55Z", + "updated_at": "2025-01-18T15:58:56Z", + "browser_download_url": "https://github.com/AlistGo/alist/releases/download/beta/alist-android-386.tar.gz", + "uploader": { + "login": "github-actions[bot]", + "id": 41898282, + "type": "Bot", + "site_admin": false + } + }, + { + "url": "https://api.github.com/repos/AlistGo/alist/releases/assets/221414214", + "id": 221414214, + "name": "alist-android-amd64.tar.gz", + "content_type": "application/gzip", + "state": "uploaded", + "size": 31586093, + "download_count": 10, + "created_at": "2025-01-18T15:58:55Z", + "updated_at": "2025-01-18T15:58:56Z", + "browser_download_url": "https://github.com/AlistGo/alist/releases/download/beta/alist-android-amd64.tar.gz", + "uploader": { + "login": "github-actions[bot]", + "id": 41898282, + "type": "Bot", + "site_admin": false + } + } + ], + "body": "Test text" + } + ]` + + var releases []Release + err := utils.Json.Unmarshal([]byte(jsonData), &releases) + assert.NoError(t, err) + assert.Len(t, releases, 1) + + release := releases[0] + // 验证 Release 基本信息 + assert.Equal(t, int64(170718825), release.ID) + assert.Equal(t, "beta", release.TagName) + assert.Equal(t, "AList Beta Version", release.Name) + assert.Equal(t, "Test text", release.Body) + assert.False(t, release.Draft) + assert.True(t, release.Prerelease) + + // 验证时间 + assert.Equal(t, 2025, release.CreatedAt.Year()) + assert.Equal(t, 2024, release.PublishedAt.Year()) + + // 验证作者信息 + assert.Equal(t, "xhofe", release.Author.Login) + assert.Equal(t, int64(36558727), release.Author.ID) + assert.Equal(t, "User", release.Author.Type) + + // 验证资源信息 + assert.Len(t, release.Assets, 2) + + // 验证第一个资源 + asset1 := release.Assets[0] + assert.Equal(t, int64(221414212), asset1.ID) + assert.Equal(t, "alist-android-386.tar.gz", asset1.Name) + assert.Equal(t, "application/gzip", asset1.ContentType) + assert.Equal(t, int64(31186443), asset1.Size) + assert.Equal(t, int64(6), asset1.DownloadCount) + assert.Equal(t, "uploaded", asset1.State) + assert.Equal(t, "https://github.com/AlistGo/alist/releases/download/beta/alist-android-386.tar.gz", asset1.BrowserDownloadURL) + + // 验证第一个资源的上传者 + assert.Equal(t, "github-actions[bot]", asset1.Uploader.Login) + assert.Equal(t, int64(41898282), asset1.Uploader.ID) + assert.Equal(t, "Bot", asset1.Uploader.Type) + + // 验证第二个资源 + asset2 := release.Assets[1] + assert.Equal(t, int64(221414214), asset2.ID) + assert.Equal(t, "alist-android-amd64.tar.gz", asset2.Name) + assert.Equal(t, int64(31586093), asset2.Size) + assert.Equal(t, int64(10), asset2.DownloadCount) +} + +func TestRelease_InterfaceMethods(t *testing.T) { + release := &Release{ + ID: 123, + TagName: "v1.0.0", + CreatedAt: time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC), + PublishedAt: time.Date(2023, 1, 2, 0, 0, 0, 0, time.UTC), + Assets: []Asset{ + {Name: "asset1.zip"}, + {Name: "asset2.tar.gz"}, + }, + } + + // 测试基本方法 + t.Run("basic methods", func(t *testing.T) { + assert.Equal(t, int64(0), release.GetSize()) + assert.Equal(t, "v1.0.0", release.GetName()) + assert.Equal(t, time.Date(2023, 1, 2, 0, 0, 0, 0, time.UTC), release.ModTime()) + assert.Equal(t, time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC), release.CreateTime()) + assert.True(t, release.IsDir()) + assert.Equal(t, utils.HashInfo{}, release.GetHash()) + assert.Equal(t, "123", release.GetID()) + assert.Equal(t, "v1.0.0", release.GetPath()) + }) + + // 测试 Children 方法 + t.Run("children", func(t *testing.T) { + children, err := release.Children() + assert.NoError(t, err) + assert.Len(t, children, 2) + assert.Equal(t, "asset1.zip", children[0].GetName()) + assert.Equal(t, "asset2.tar.gz", children[1].GetName()) + }) +} + +func TestAsset_InterfaceMethods(t *testing.T) { + now := time.Now() + asset := &Asset{ + ID: 456, + Name: "test.zip", + Size: 12345, + CreatedAt: &now, + UpdatedAt: &now, + BrowserDownloadURL: "https://github.com/owner/repo/releases/download/v1.0.0/test.zip", + } + + t.Run("basic methods", func(t *testing.T) { + assert.Equal(t, int64(12345), asset.GetSize()) + assert.Equal(t, "test.zip", asset.GetName()) + assert.Equal(t, now, asset.ModTime()) + assert.Equal(t, now, asset.CreateTime()) + assert.False(t, asset.IsDir()) + assert.Equal(t, utils.HashInfo{}, asset.GetHash()) + assert.Equal(t, "456", asset.GetID()) + }) + + // 测试空时间的情况 + t.Run("nil time fields", func(t *testing.T) { + emptyAsset := &Asset{} + assert.Equal(t, time.Time{}, emptyAsset.ModTime()) + assert.Equal(t, time.Time{}, emptyAsset.CreateTime()) + }) +} + +func TestAsset_GetPath(t *testing.T) { + tests := []struct { + name string + browserDownloadURL string + want string + }{ + { + name: "valid url", + browserDownloadURL: "https://github.com/owner/repo/releases/download/v1.0.0/test.zip", + want: "v1.0.0/test.zip", + }, + { + name: "invalid url format", + browserDownloadURL: "https://github.com/invalid/url", + want: "", + }, + { + name: "empty url", + browserDownloadURL: "", + want: "", + }, + { + name: "url with special characters", + browserDownloadURL: "https://github.com/owner/repo/releases/download/v1.0.0-beta/test-file_1.2.3.zip", + want: "v1.0.0-beta/test-file_1.2.3.zip", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + asset := &Asset{ + BrowserDownloadURL: tt.browserDownloadURL, + } + got := asset.GetPath() + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/drivers/github_release/util.go b/drivers/github_release/util.go new file mode 100644 index 00000000000..38cdfe4490d --- /dev/null +++ b/drivers/github_release/util.go @@ -0,0 +1 @@ +package template From c14fe2f7dd48297a88f99502cc73f2e9db70dcd3 Mon Sep 17 00:00:00 2001 From: Zhang JL Date: Tue, 21 Jan 2025 22:36:51 +0800 Subject: [PATCH 2/8] chore(github_release): enhance help messages for repo, token, and max_releases fields in Addition struct --- drivers/github_release/meta.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/drivers/github_release/meta.go b/drivers/github_release/meta.go index f925972391c..60c0252fae9 100644 --- a/drivers/github_release/meta.go +++ b/drivers/github_release/meta.go @@ -8,9 +8,9 @@ import ( type Addition struct { driver.RootID // define other - Repo string `json:"repo" required:"true" default:"AlistGo/alist"` - Token string `json:"token" required:"true" default:""` - MaxReleases int `json:"max_releases" required:"true" type:"number" default:"30" help:"max releases to list"` + Repo string `json:"repo" required:"true" default:"AlistGo/alist" help:"Repository name(owner/repo)"` + Token string `json:"token" required:"true" default:"" help:"Github personal access token"` + MaxReleases int `json:"max_releases" required:"true" type:"number" default:"30" help:"Max releases to list"` } var config = driver.Config{ From b266a31bb3b526a9fc56a13439bb1604e987a0a0 Mon Sep 17 00:00:00 2001 From: Zhang JL Date: Fri, 24 Jan 2025 17:44:53 +0800 Subject: [PATCH 3/8] feat(github_release): add latest release support enhance driver functionality with validation, concurrent release fetching --- drivers/github_release/driver.go | 130 ++++++++++++++++++++++++------- drivers/github_release/github.go | 37 +++++++++ drivers/github_release/meta.go | 1 + drivers/github_release/types.go | 9 +++ 4 files changed, 147 insertions(+), 30 deletions(-) diff --git a/drivers/github_release/driver.go b/drivers/github_release/driver.go index bc7427edb20..51b3d0fa087 100644 --- a/drivers/github_release/driver.go +++ b/drivers/github_release/driver.go @@ -12,8 +12,10 @@ import ( "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" "github.com/pkg/errors" + "golang.org/x/sync/errgroup" ) +// GithubRelease implements a driver for GitHub Release type GithubRelease struct { model.Storage Addition @@ -22,6 +24,7 @@ type GithubRelease struct { repo repository } +// Config returns the driver config func (d *GithubRelease) Config() driver.Config { return config } @@ -30,9 +33,9 @@ func (d *GithubRelease) GetAddition() driver.Additional { return &d.Addition } -func (d *GithubRelease) Init(ctx context.Context) error { - token := d.Addition.Token - if token == "" { +// validate checks if the driver configuration is valid +func (d *GithubRelease) validate() error { + if d.Addition.Token == "" { return errs.EmptyToken } @@ -44,80 +47,142 @@ func (d *GithubRelease) Init(ctx context.Context) error { d.Addition.MaxReleases = 100 } - d.api = NewApiContext(token, nil) + return nil +} + +// Init initializes the driver +func (d *GithubRelease) Init(ctx context.Context) error { + if err := d.validate(); err != nil { + return err + } + + d.api = NewApiContext(d.Addition.Token, nil) repo, err := newRepository(d.Addition.Repo) if err != nil { - return err + return errors.Wrap(err, "failed to create repository") } d.repo = repo return nil } -// Drop Delete this driver +// Drop deletes this driver func (d *GithubRelease) Drop(ctx context.Context) error { return nil } -func (d *GithubRelease) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { - repo, err := newRepository(d.Addition.Repo) +// listReleases gets all releases +func (d *GithubRelease) listReleases(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + g, ctx := errgroup.WithContext(ctx) + + var releases []model.Obj + var latest model.Obj + + // Get latest release if enabled + if d.Addition.ShowLatest { + g.Go(func() error { + release, err := d.api.GetLatestRelease(d.repo) + if err != nil { + if err == ErrNoRelease { + // for no release, just return + return nil + } + return errors.Wrap(err, "failed to get latest release") + } + latest = release + return nil + }) + } + + // Get all releases + g.Go(func() error { + r, err := d.api.GetReleases(d.repo, d.Addition.MaxReleases) + if err != nil { + return errors.Wrap(err, "failed to get releases") + } + releases = r + return nil + }) + + // Wait for all goroutines to complete + if err := g.Wait(); err != nil { + return nil, err + } + + // Add latest release to the top if available + if latest != nil && releases != nil { + releases = append([]model.Obj{latest}, releases...) + } + + return releases, nil +} + +func (d *GithubRelease) listReleaseAssets(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + idStr := dir.GetID() + id, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + return nil, errors.Wrapf(err, "list release %s failed, id is not a number", idStr) + } + release, err := d.api.GetRelease(d.repo, id) if err != nil { return nil, err } + return release.Children() +} - // 判断 dir 是不是挂在点。如果 dir 是挂载点,则返回所有的 release; - // 如果 dir 不是挂载点,则返回 dir 下的 release。 +// List returns the objects in the given directory +func (d *GithubRelease) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + // If dir is root, return all releases if dir.GetPath() == "" { - releases, err := d.api.GetReleases(repo, d.Addition.MaxReleases) - if err != nil { - return nil, err - } - return releases, nil + return d.listReleases(ctx, dir, args) } + // Otherwise return release assets idStr := dir.GetID() id, err := strconv.ParseInt(idStr, 10, 64) if err != nil { return nil, errors.Wrapf(err, "list release %s failed, id is not a number", idStr) } - release, err := d.api.GetRelease(repo, id) + release, err := d.api.GetRelease(d.repo, id) if err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to get release") } return release.Children() } +// proxyDownload checks if download should be proxied func (d *GithubRelease) proxyDownload(file model.Obj, args model.LinkArgs) bool { + // Must proxy if configured if d.Config().MustProxy() || d.GetStorage().WebProxy { return true } - req := args.HttpReq - if args.HttpReq != nil && - req.URL != nil && - strings.HasPrefix(req.URL.Path, fmt.Sprintf("/p%s", d.GetStorage().MountPath)) { - return true + // Check if request path indicates proxy is needed + if req := args.HttpReq; req != nil && req.URL != nil { + proxyPath := fmt.Sprintf("/p%s", d.GetStorage().MountPath) + return strings.HasPrefix(req.URL.Path, proxyPath) } return false } +// Link returns the download link for a file func (d *GithubRelease) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { idStr := file.GetID() id, err := strconv.ParseInt(idStr, 10, 64) if err != nil { return nil, errors.Wrapf(err, "get link of file %s failed, id is not a number", idStr) } + asset, err := d.api.GetReleaseAsset(d.repo, id) if err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to get release asset") } if d.proxyDownload(file, args) { - header := http.Header{ "User-Agent": {"Alist/" + conf.VERSION}, "Accept": {"application/octet-stream"}, @@ -133,36 +198,41 @@ func (d *GithubRelease) Link(ctx context.Context, file model.Obj, args model.Lin return &model.Link{ URL: asset.BrowserDownloadURL, }, nil - } +// MakeDir is not supported func (d *GithubRelease) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) (model.Obj, error) { return nil, errs.NotSupport } +// Move is not supported func (d *GithubRelease) Move(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) { return nil, errs.NotSupport } +// Rename is not supported func (d *GithubRelease) Rename(ctx context.Context, srcObj model.Obj, newName string) (model.Obj, error) { - // TODO rename obj, optional - return nil, errs.NotImplement + return nil, errs.NotSupport } +// Copy is not supported func (d *GithubRelease) Copy(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) { return nil, errs.NotSupport } +// Remove is not supported func (d *GithubRelease) Remove(ctx context.Context, obj model.Obj) error { return errs.NotSupport } +// Put is not supported func (d *GithubRelease) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { return nil, errs.NotSupport } -//func (d *Template) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) { -// return nil, errs.NotSupport -//} +// Other implements custom operations +func (d *GithubRelease) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) { + return nil, errs.NotSupport +} var _ driver.Driver = (*GithubRelease)(nil) diff --git a/drivers/github_release/github.go b/drivers/github_release/github.go index 9b532d0c6dd..cf6f2e6c5a5 100644 --- a/drivers/github_release/github.go +++ b/drivers/github_release/github.go @@ -216,3 +216,40 @@ func (a *ApiContext) GetReleaseAsset(repo repository, ID int64) (*Asset, error) return &asset, nil } + +var ( + ErrNoRelease = errors.New("no release found") +) + +// GetLatestRelease 获取最新 release. +func (a *ApiContext) GetLatestRelease(repo repository) (model.Obj, error) { + url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", repo.UrlEncode()) + response, err := a.getWithRetry(url) + if err != nil { + return nil, err + } + defer response.Body.Close() + + body, err := io.ReadAll(response.Body) + if err != nil { + return nil, errors.Wrap(err, "get latest release failed") + } + + if response.StatusCode != http.StatusOK { + if response.StatusCode == http.StatusNotFound { + // identify no release + return nil, ErrNoRelease + } + return nil, parseHTTPError(body) + } + + release := Release{} + err = utils.Json.Unmarshal(body, &release) + if err != nil { + return nil, errors.Wrap(err, "get latest release failed") + } + + release.SetLatestFlag(true) + + return &release, nil +} diff --git a/drivers/github_release/meta.go b/drivers/github_release/meta.go index 60c0252fae9..5ff3f10891e 100644 --- a/drivers/github_release/meta.go +++ b/drivers/github_release/meta.go @@ -11,6 +11,7 @@ type Addition struct { Repo string `json:"repo" required:"true" default:"AlistGo/alist" help:"Repository name(owner/repo)"` Token string `json:"token" required:"true" default:"" help:"Github personal access token"` MaxReleases int `json:"max_releases" required:"true" type:"number" default:"30" help:"Max releases to list"` + ShowLatest bool `json:"show_latest" type:"bool" default:"true" help:"Show latest release on top"` } var config = driver.Config{ diff --git a/drivers/github_release/types.go b/drivers/github_release/types.go index ca6205ec114..56d4018419b 100644 --- a/drivers/github_release/types.go +++ b/drivers/github_release/types.go @@ -62,6 +62,8 @@ type Release struct { BodyText string `json:"body_text"` MentionsCount int `json:"mentions_count"` DiscussionURL string `json:"discussion_url"` + + latest bool } func (r *Release) UnmarshalJSON(data []byte) error { @@ -101,7 +103,14 @@ func (r *Release) GetSize() int64 { return 0 } +func (r *Release) SetLatestFlag(flag bool) { + r.latest = flag +} + func (r *Release) GetName() string { + if r.latest { + return fmt.Sprintf("latest(%s)", r.TagName) + } return r.TagName } From 7b1dd2fdeb7a70bda0f0367869045e2ae9c05972 Mon Sep 17 00:00:00 2001 From: Zhang JL Date: Fri, 24 Jan 2025 22:40:18 +0800 Subject: [PATCH 4/8] chore(github_release): adjust backoff parameters for retry logic --- drivers/github_release/backoff.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/drivers/github_release/backoff.go b/drivers/github_release/backoff.go index f6edb2c1cae..875cde09f7a 100644 --- a/drivers/github_release/backoff.go +++ b/drivers/github_release/backoff.go @@ -7,8 +7,8 @@ import ( const ( initialRetryInterval = 500 * time.Millisecond - maxInterval = 1 * time.Minute - maxElapsedTime = 15 * time.Minute + maxInterval = 10 * time.Second + maxElapsedTime = 30 * time.Second randomizationFactor = 0.5 multiplier = 1.5 ) From becff3355ac09341fc5a4c3877a369b561b12b5f Mon Sep 17 00:00:00 2001 From: Zhang JL Date: Fri, 24 Jan 2025 22:54:52 +0800 Subject: [PATCH 5/8] refactor(github_release): update API methods to accept context for improved concurrency and error handling --- drivers/github_release/driver.go | 10 +++--- drivers/github_release/github.go | 52 +++++++++++++++++++++++--------- 2 files changed, 43 insertions(+), 19 deletions(-) diff --git a/drivers/github_release/driver.go b/drivers/github_release/driver.go index 51b3d0fa087..f9b574406f9 100644 --- a/drivers/github_release/driver.go +++ b/drivers/github_release/driver.go @@ -82,7 +82,7 @@ func (d *GithubRelease) listReleases(ctx context.Context, dir model.Obj, args mo // Get latest release if enabled if d.Addition.ShowLatest { g.Go(func() error { - release, err := d.api.GetLatestRelease(d.repo) + release, err := d.api.GetLatestRelease(ctx, d.repo) if err != nil { if err == ErrNoRelease { // for no release, just return @@ -97,7 +97,7 @@ func (d *GithubRelease) listReleases(ctx context.Context, dir model.Obj, args mo // Get all releases g.Go(func() error { - r, err := d.api.GetReleases(d.repo, d.Addition.MaxReleases) + r, err := d.api.GetReleases(ctx, d.repo, d.Addition.MaxReleases) if err != nil { return errors.Wrap(err, "failed to get releases") } @@ -124,7 +124,7 @@ func (d *GithubRelease) listReleaseAssets(ctx context.Context, dir model.Obj, ar if err != nil { return nil, errors.Wrapf(err, "list release %s failed, id is not a number", idStr) } - release, err := d.api.GetRelease(d.repo, id) + release, err := d.api.GetRelease(ctx, d.repo, id) if err != nil { return nil, err } @@ -145,7 +145,7 @@ func (d *GithubRelease) List(ctx context.Context, dir model.Obj, args model.List return nil, errors.Wrapf(err, "list release %s failed, id is not a number", idStr) } - release, err := d.api.GetRelease(d.repo, id) + release, err := d.api.GetRelease(ctx, d.repo, id) if err != nil { return nil, errors.Wrap(err, "failed to get release") } @@ -177,7 +177,7 @@ func (d *GithubRelease) Link(ctx context.Context, file model.Obj, args model.Lin return nil, errors.Wrapf(err, "get link of file %s failed, id is not a number", idStr) } - asset, err := d.api.GetReleaseAsset(d.repo, id) + asset, err := d.api.GetReleaseAsset(ctx, d.repo, id) if err != nil { return nil, errors.Wrap(err, "failed to get release asset") } diff --git a/drivers/github_release/github.go b/drivers/github_release/github.go index cf6f2e6c5a5..c7983860d98 100644 --- a/drivers/github_release/github.go +++ b/drivers/github_release/github.go @@ -1,6 +1,7 @@ package template import ( + "context" "fmt" "io" "net/http" @@ -54,12 +55,29 @@ func parseHTTPError(body []byte) error { return errors.New(message) } +// sleepWithContext 在指定的时间内等待, 如果 context 被取消则提前返回. +func sleepWithContext(ctx context.Context, d time.Duration) error { + timer := time.NewTimer(d) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} + // getWithRetry 获取 GitHub API 并重试. -func (a *ApiContext) getWithRetry(url string) (*http.Response, error) { +func (a *ApiContext) getWithRetry(ctx context.Context, url string) (*http.Response, error) { backoff := Backoff{} for { - response, err := a.get(url) + if err := ctx.Err(); err != nil { + return nil, err + } + + response, err := a.get(ctx, url) // non-2xx code does not cause error if err != nil { @@ -69,7 +87,10 @@ func (a *ApiContext) getWithRetry(url string) (*http.Response, error) { return nil, errors.Wrap(err, "request failed") } utils.Log.Debugf("query github api error: %s, retry after %s", err, p) - time.Sleep(p) + + if err := sleepWithContext(ctx, p); err != nil { + return nil, err + } continue } @@ -97,7 +118,10 @@ func (a *ApiContext) getWithRetry(url string) (*http.Response, error) { return nil, parseHTTPError(body) } utils.Log.Debugf("query github api error: status code %d, retry after %s", response.StatusCode, p) - time.Sleep(p) + + if err := sleepWithContext(ctx, p); err != nil { + return nil, err + } continue } @@ -112,8 +136,8 @@ func (a *ApiContext) SetAuthHeader(header http.Header) { } // get 获取 GitHub API. -func (a *ApiContext) get(url string) (*http.Response, error) { - request, err := http.NewRequest("GET", url, nil) +func (a *ApiContext) get(ctx context.Context, url string) (*http.Response, error) { + request, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return nil, err } @@ -130,12 +154,12 @@ func (a *ApiContext) get(url string) (*http.Response, error) { } // GetReleases 获取仓库信息. -func (a *ApiContext) GetReleases(repo repository, perPage int) ([]model.Obj, error) { +func (a *ApiContext) GetReleases(ctx context.Context, repo repository, perPage int) ([]model.Obj, error) { if perPage < 1 { perPage = 30 } url := fmt.Sprintf("https://api.github.com/repos/%s/releases?per_page=%d", repo.UrlEncode(), perPage) - response, err := a.getWithRetry(url) + response, err := a.getWithRetry(ctx, url) if err != nil { return nil, err } @@ -164,9 +188,9 @@ func (a *ApiContext) GetReleases(repo repository, perPage int) ([]model.Obj, err } // GetRelease 获取指定 tag 的 release. -func (a *ApiContext) GetRelease(repo repository, id int64) (*Release, error) { +func (a *ApiContext) GetRelease(ctx context.Context, repo repository, id int64) (*Release, error) { url := fmt.Sprintf("https://api.github.com/repos/%s/releases/%d", repo.UrlEncode(), id) - response, err := a.getWithRetry(url) + response, err := a.getWithRetry(ctx, url) if err != nil { return nil, err } @@ -191,9 +215,9 @@ func (a *ApiContext) GetRelease(repo repository, id int64) (*Release, error) { } // GetReleaseAsset 获取指定 tag 的 release 的 assets. -func (a *ApiContext) GetReleaseAsset(repo repository, ID int64) (*Asset, error) { +func (a *ApiContext) GetReleaseAsset(ctx context.Context, repo repository, ID int64) (*Asset, error) { url := fmt.Sprintf("https://api.github.com/repos/%s/releases/assets/%d", repo.UrlEncode(), ID) - response, err := a.getWithRetry(url) + response, err := a.getWithRetry(ctx, url) if err != nil { return nil, err } @@ -222,9 +246,9 @@ var ( ) // GetLatestRelease 获取最新 release. -func (a *ApiContext) GetLatestRelease(repo repository) (model.Obj, error) { +func (a *ApiContext) GetLatestRelease(ctx context.Context, repo repository) (model.Obj, error) { url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", repo.UrlEncode()) - response, err := a.getWithRetry(url) + response, err := a.getWithRetry(ctx, url) if err != nil { return nil, err } From 009a46aca64b6860325bb7e8e50e60ff2f920ec5 Mon Sep 17 00:00:00 2001 From: Zhang JL Date: Fri, 24 Jan 2025 23:10:24 +0800 Subject: [PATCH 6/8] chore(github_release): rename package --- drivers/github_release/backoff.go | 2 +- drivers/github_release/backoff_test.go | 2 +- drivers/github_release/driver.go | 2 +- drivers/github_release/github.go | 2 +- drivers/github_release/meta.go | 2 +- drivers/github_release/types.go | 2 +- drivers/github_release/types_test.go | 2 +- drivers/github_release/util.go | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/drivers/github_release/backoff.go b/drivers/github_release/backoff.go index 875cde09f7a..224e783be97 100644 --- a/drivers/github_release/backoff.go +++ b/drivers/github_release/backoff.go @@ -1,4 +1,4 @@ -package template +package github_release import ( "math/rand" diff --git a/drivers/github_release/backoff_test.go b/drivers/github_release/backoff_test.go index f6f088ef6b3..d2d13a36b11 100644 --- a/drivers/github_release/backoff_test.go +++ b/drivers/github_release/backoff_test.go @@ -1,4 +1,4 @@ -package template +package github_release import ( "testing" diff --git a/drivers/github_release/driver.go b/drivers/github_release/driver.go index f9b574406f9..a6f6d69f1ba 100644 --- a/drivers/github_release/driver.go +++ b/drivers/github_release/driver.go @@ -1,4 +1,4 @@ -package template +package github_release import ( "context" diff --git a/drivers/github_release/github.go b/drivers/github_release/github.go index c7983860d98..5505122e122 100644 --- a/drivers/github_release/github.go +++ b/drivers/github_release/github.go @@ -1,4 +1,4 @@ -package template +package github_release import ( "context" diff --git a/drivers/github_release/meta.go b/drivers/github_release/meta.go index 5ff3f10891e..00e495a32b1 100644 --- a/drivers/github_release/meta.go +++ b/drivers/github_release/meta.go @@ -1,4 +1,4 @@ -package template +package github_release import ( "github.com/alist-org/alist/v3/internal/driver" diff --git a/drivers/github_release/types.go b/drivers/github_release/types.go index 56d4018419b..4751c135104 100644 --- a/drivers/github_release/types.go +++ b/drivers/github_release/types.go @@ -1,4 +1,4 @@ -package template +package github_release import ( "fmt" diff --git a/drivers/github_release/types_test.go b/drivers/github_release/types_test.go index dd24fbe767b..9bc43bfdb4c 100644 --- a/drivers/github_release/types_test.go +++ b/drivers/github_release/types_test.go @@ -1,4 +1,4 @@ -package template +package github_release import ( "testing" diff --git a/drivers/github_release/util.go b/drivers/github_release/util.go index 38cdfe4490d..eb1164d7976 100644 --- a/drivers/github_release/util.go +++ b/drivers/github_release/util.go @@ -1 +1 @@ -package template +package github_release From 39a924d4c37c8fd1df8d1e6069065223e63ee22c Mon Sep 17 00:00:00 2001 From: Zhang JL Date: Sat, 25 Jan 2025 00:10:56 +0800 Subject: [PATCH 7/8] fix(github_release): correcting a test iteration number --- drivers/github_release/backoff_test.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/drivers/github_release/backoff_test.go b/drivers/github_release/backoff_test.go index d2d13a36b11..288e57ad181 100644 --- a/drivers/github_release/backoff_test.go +++ b/drivers/github_release/backoff_test.go @@ -7,11 +7,12 @@ import ( func TestBackoffMultiple(t *testing.T) { b := &Backoff{} - for i := 0; i < 19; i++ { + for i := 0; i < 10; i++ { p, ok := b.Pause() t.Logf("iteration %d pausing for %s", i, p) if !ok { - t.Fatalf("hit the pause timeout after %d pauses", i) + t.Logf("hit the pause timeout after %d pauses", i) + return } } } From 0bce51dd539aad1e2a0f34338ff0c754ce7d071f Mon Sep 17 00:00:00 2001 From: Zhang JL Date: Sat, 25 Jan 2025 00:14:39 +0800 Subject: [PATCH 8/8] fix(github_release): corect the error logic when the latest release is not found * rename APIContext * add github_test.go * add github api rate limit log --- drivers/github_release/driver.go | 4 +- drivers/github_release/github.go | 207 ++++++++++++++++---------- drivers/github_release/github_test.go | 155 +++++++++++++++++++ 3 files changed, 284 insertions(+), 82 deletions(-) create mode 100644 drivers/github_release/github_test.go diff --git a/drivers/github_release/driver.go b/drivers/github_release/driver.go index a6f6d69f1ba..18dd57b54b4 100644 --- a/drivers/github_release/driver.go +++ b/drivers/github_release/driver.go @@ -20,7 +20,7 @@ type GithubRelease struct { model.Storage Addition - api *ApiContext + api *APIContext repo repository } @@ -56,7 +56,7 @@ func (d *GithubRelease) Init(ctx context.Context) error { return err } - d.api = NewApiContext(d.Addition.Token, nil) + d.api = NewAPIContext(d.Addition.Token, nil) repo, err := newRepository(d.Addition.Repo) if err != nil { diff --git a/drivers/github_release/github.go b/drivers/github_release/github.go index 5505122e122..0edbbd8b103 100644 --- a/drivers/github_release/github.go +++ b/drivers/github_release/github.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net/http" + "strconv" "time" "github.com/alist-org/alist/v3/internal/model" @@ -12,47 +13,83 @@ import ( "github.com/pkg/errors" ) -const GITHUB_API_VERSION = "2022-11-28" +const ( + GITHUB_API_VERSION = "2022-11-28" + DEFAULT_TIMEOUT = 10 * time.Second +) -type ApiContext struct { - token string - version string - client *http.Client -} +var ErrRateLimitExceeded = errors.New("rate limit exceeded") -func NewApiContext(token string, client *http.Client) *ApiContext { - ret := ApiContext{ - token: token, - version: GITHUB_API_VERSION, - client: client, - } +// RateLimit 表示 GitHub API 的速率限制信息 +type RateLimit struct { + Limit uint + Remaining uint + Reset time.Time +} - if ret.client == nil { - ret.client = http.DefaultClient - } +// GitHubError 表示 GitHub API 返回的错误信息 +type GitHubError struct { + Message string `json:"message"` + DocumentationURL string `json:"documentation_url"` + StatusCode int +} - return &ret +func (e *GitHubError) Error() string { + return fmt.Sprintf("github api error: %s (status: %d)", e.Message, e.StatusCode) } -// parseHTTPError 解析 HTTP 错误. -func parseHTTPError(body []byte) error { - var v map[string]interface{} +// parseHTTPError 解析 GitHub API 的错误响应 +func parseHTTPError(statusCode int, body []byte) error { + var v GitHubError err := utils.Json.Unmarshal(body, &v) if err != nil { - return errors.New(string(body)) + return &GitHubError{ + Message: string(body), + StatusCode: statusCode, + } + } + v.StatusCode = statusCode + return &v +} + +// parseRateLimit 从响应头中解析速率限制信息 +func parseRateLimit(header http.Header) *RateLimit { + limit, _ := strconv.Atoi(header.Get("X-RateLimit-Limit")) + remaining, _ := strconv.Atoi(header.Get("X-RateLimit-Remaining")) + reset, _ := strconv.ParseInt(header.Get("X-RateLimit-Reset"), 10, 64) + + return &RateLimit{ + Limit: uint(limit), + Remaining: uint(remaining), + Reset: time.Unix(reset, 0), } +} - iface, ok := v["message"] - if !ok { - return errors.New(string(body)) +// APIContext 表示 GitHub API 的上下文信息 +type APIContext struct { + token string + version string + client *http.Client + defaultTimeout time.Duration + rateLimit *RateLimit +} + +// NewAPIContext 创建一个新的 GitHub API 上下文 +func NewAPIContext(token string, client *http.Client) *APIContext { + ret := APIContext{ + token: token, + version: GITHUB_API_VERSION, + client: client, + defaultTimeout: DEFAULT_TIMEOUT, } - message, ok := iface.(string) - if !ok { - return errors.New(string(body)) + if ret.client == nil { + ret.client = &http.Client{ + Timeout: ret.defaultTimeout, + } } - return errors.New(message) + return &ret } // sleepWithContext 在指定的时间内等待, 如果 context 被取消则提前返回. @@ -69,7 +106,7 @@ func sleepWithContext(ctx context.Context, d time.Duration) error { } // getWithRetry 获取 GitHub API 并重试. -func (a *ApiContext) getWithRetry(ctx context.Context, url string) (*http.Response, error) { +func (a *APIContext) getWithRetry(ctx context.Context, url string) (*http.Response, error) { backoff := Backoff{} for { @@ -81,6 +118,11 @@ func (a *ApiContext) getWithRetry(ctx context.Context, url string) (*http.Respon // non-2xx code does not cause error if err != nil { + // 如果错误是速率限制错误, 则直接返回 + if errors.Is(err, ErrRateLimitExceeded) { + return nil, err + } + // retry when error is not nil p, retryAgain := backoff.Pause() if !retryAgain { @@ -115,7 +157,7 @@ func (a *ApiContext) getWithRetry(ctx context.Context, url string) (*http.Respon // retry when server error p, retryAgain := backoff.Pause() if !retryAgain { - return nil, parseHTTPError(body) + return nil, parseHTTPError(response.StatusCode, body) } utils.Log.Debugf("query github api error: status code %d, retry after %s", response.StatusCode, p) @@ -125,18 +167,18 @@ func (a *ApiContext) getWithRetry(ctx context.Context, url string) (*http.Respon continue } - return nil, parseHTTPError(body) + return nil, parseHTTPError(response.StatusCode, body) } } // SetAuthHeader 为请求头添加 GitHub API 所需的认证头. // 这是一个副作用函数, 会直接修改传入的 header. -func (a *ApiContext) SetAuthHeader(header http.Header) { +func (a *APIContext) SetAuthHeader(header http.Header) { header.Set("Authorization", fmt.Sprintf("Bearer %s", a.token)) } // get 获取 GitHub API. -func (a *ApiContext) get(ctx context.Context, url string) (*http.Response, error) { +func (a *APIContext) get(ctx context.Context, url string) (*http.Response, error) { request, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return nil, err @@ -150,11 +192,21 @@ func (a *ApiContext) get(ctx context.Context, url string) (*http.Response, error return nil, err } + // 更新速率限制信息 + a.rateLimit = parseRateLimit(response.Header) + + // 如果剩余请求数为 0, 等待到重置时间 + if a.rateLimit.Remaining == 0 { + waitTime := time.Until(a.rateLimit.Reset) + utils.Log.Warnf("rate limit exceeded, will wait for %s", waitTime) + return nil, ErrRateLimitExceeded + } + return response, nil } // GetReleases 获取仓库信息. -func (a *ApiContext) GetReleases(ctx context.Context, repo repository, perPage int) ([]model.Obj, error) { +func (a *APIContext) GetReleases(ctx context.Context, repo repository, perPage int) ([]model.Obj, error) { if perPage < 1 { perPage = 30 } @@ -170,10 +222,6 @@ func (a *ApiContext) GetReleases(ctx context.Context, repo repository, perPage i return nil, errors.Wrap(err, "failed to read response body") } - if response.StatusCode != http.StatusOK { - return nil, parseHTTPError(body) - } - releases := []Release{} err = utils.Json.Unmarshal(body, &releases) if err != nil { @@ -187,36 +235,49 @@ func (a *ApiContext) GetReleases(ctx context.Context, repo repository, perPage i return tree, nil } -// GetRelease 获取指定 tag 的 release. -func (a *ApiContext) GetRelease(ctx context.Context, repo repository, id int64) (*Release, error) { - url := fmt.Sprintf("https://api.github.com/repos/%s/releases/%d", repo.UrlEncode(), id) +// GetLatestRelease 获取最新 release. +func (a *APIContext) GetLatestRelease(ctx context.Context, repo repository) (model.Obj, error) { + url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", repo.UrlEncode()) response, err := a.getWithRetry(ctx, url) if err != nil { - return nil, err + var githubErr *GitHubError + if errors.As(err, &githubErr) && githubErr.StatusCode == http.StatusNotFound { + return nil, ErrNoRelease + } + return nil, errors.Wrap(err, "get latest release") } defer response.Body.Close() body, err := io.ReadAll(response.Body) if err != nil { - return nil, errors.Wrap(err, "failed to read response body") + return nil, errors.Wrap(err, "read response body") + } + + if response.StatusCode == http.StatusNotFound { + return nil, ErrNoRelease } if response.StatusCode != http.StatusOK { - return nil, parseHTTPError(body) + err := parseHTTPError(response.StatusCode, body) + var githubErr *GitHubError + if errors.As(err, &githubErr) && githubErr.StatusCode == http.StatusNotFound { + return nil, ErrNoRelease + } + return nil, err } - release := Release{} - err = utils.Json.Unmarshal(body, &release) - if err != nil { - return nil, errors.Wrap(err, "failed to unmarshal release") + var release Release + if err := utils.Json.Unmarshal(body, &release); err != nil { + return nil, errors.Wrap(err, "unmarshal release data") } + release.SetLatestFlag(true) return &release, nil } -// GetReleaseAsset 获取指定 tag 的 release 的 assets. -func (a *ApiContext) GetReleaseAsset(ctx context.Context, repo repository, ID int64) (*Asset, error) { - url := fmt.Sprintf("https://api.github.com/repos/%s/releases/assets/%d", repo.UrlEncode(), ID) +// GetRelease 获取指定 tag 的 release. +func (a *APIContext) GetRelease(ctx context.Context, repo repository, id int64) (*Release, error) { + url := fmt.Sprintf("https://api.github.com/repos/%s/releases/%d", repo.UrlEncode(), id) response, err := a.getWithRetry(ctx, url) if err != nil { return nil, err @@ -228,26 +289,18 @@ func (a *ApiContext) GetReleaseAsset(ctx context.Context, repo repository, ID in return nil, errors.Wrap(err, "failed to read response body") } - if response.StatusCode != http.StatusOK { - return nil, parseHTTPError(body) - } - - asset := Asset{} - err = utils.Json.Unmarshal(body, &asset) + release := Release{} + err = utils.Json.Unmarshal(body, &release) if err != nil { - return nil, errors.Wrap(err, "failed to unmarshal asset") + return nil, errors.Wrap(err, "failed to unmarshal release") } - return &asset, nil + return &release, nil } -var ( - ErrNoRelease = errors.New("no release found") -) - -// GetLatestRelease 获取最新 release. -func (a *ApiContext) GetLatestRelease(ctx context.Context, repo repository) (model.Obj, error) { - url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", repo.UrlEncode()) +// GetReleaseAsset 获取指定 tag 的 release 的 assets. +func (a *APIContext) GetReleaseAsset(ctx context.Context, repo repository, ID int64) (*Asset, error) { + url := fmt.Sprintf("https://api.github.com/repos/%s/releases/assets/%d", repo.UrlEncode(), ID) response, err := a.getWithRetry(ctx, url) if err != nil { return nil, err @@ -256,24 +309,18 @@ func (a *ApiContext) GetLatestRelease(ctx context.Context, repo repository) (mod body, err := io.ReadAll(response.Body) if err != nil { - return nil, errors.Wrap(err, "get latest release failed") - } - - if response.StatusCode != http.StatusOK { - if response.StatusCode == http.StatusNotFound { - // identify no release - return nil, ErrNoRelease - } - return nil, parseHTTPError(body) + return nil, errors.Wrap(err, "failed to read response body") } - release := Release{} - err = utils.Json.Unmarshal(body, &release) + asset := Asset{} + err = utils.Json.Unmarshal(body, &asset) if err != nil { - return nil, errors.Wrap(err, "get latest release failed") + return nil, errors.Wrap(err, "failed to unmarshal asset") } - release.SetLatestFlag(true) - - return &release, nil + return &asset, nil } + +var ( + ErrNoRelease = errors.New("no release found") +) diff --git a/drivers/github_release/github_test.go b/drivers/github_release/github_test.go new file mode 100644 index 00000000000..cc4b03dbff6 --- /dev/null +++ b/drivers/github_release/github_test.go @@ -0,0 +1,155 @@ +package github_release + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestParseRateLimit(t *testing.T) { + header := http.Header{} + header.Set("X-RateLimit-Limit", "60") + header.Set("X-RateLimit-Remaining", "59") + header.Set("X-RateLimit-Reset", "1735689600") // 2025-01-01 00:00:00 UTC + + rateLimit := parseRateLimit(header) + + assert.Equal(t, uint(60), rateLimit.Limit) + assert.Equal(t, uint(59), rateLimit.Remaining) + assert.Equal(t, time.Unix(1735689600, 0), rateLimit.Reset) +} + +func TestGitHubError(t *testing.T) { + err := &GitHubError{ + Message: "API rate limit exceeded", + StatusCode: 403, + } + + assert.Equal(t, "github api error: API rate limit exceeded (status: 403)", err.Error()) +} + +func TestNewAPIContext(t *testing.T) { + token := "test-token" + client := &http.Client{} + ctx := NewAPIContext(token, client) + + assert.Equal(t, token, ctx.token) + assert.Equal(t, GITHUB_API_VERSION, ctx.version) + assert.Equal(t, client, ctx.client) + assert.Equal(t, DEFAULT_TIMEOUT, ctx.defaultTimeout) +} + +func TestAPIContext_SetAuthHeader(t *testing.T) { + token := "test-token" + ctx := NewAPIContext(token, nil) + header := http.Header{} + + ctx.SetAuthHeader(header) + assert.Equal(t, "Bearer "+token, header.Get("Authorization")) +} + +func TestAPIContext_GetWithRetry_RateLimit(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-RateLimit-Limit", "60") + w.Header().Set("X-RateLimit-Remaining", "0") + w.Header().Set("X-RateLimit-Reset", "1735689600") + w.WriteHeader(http.StatusForbidden) + w.Write([]byte(`{"message": "API rate limit exceeded"}`)) + })) + defer server.Close() + + ctx := NewAPIContext("test-token", server.Client()) + _, err := ctx.getWithRetry(context.Background(), server.URL) + + assert.ErrorIs(t, err, ErrRateLimitExceeded) +} + +type testRoundTripper struct { + handler http.HandlerFunc +} + +func (t *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + // 创建一个响应记录器 + w := httptest.NewRecorder() + // 调用处理函数 + t.handler.ServeHTTP(w, req) + // 将响应记录器转换为响应 + return w.Result(), nil +} + +func TestAPIContext_GetLatestRelease(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 验证请求路径 + assert.Equal(t, "/repos/test-owner/test-repo/releases/latest", r.URL.Path) + + // 验证请求头 + assert.Equal(t, "Bearer test-token", r.Header.Get("Authorization")) + assert.Equal(t, "application/vnd.github+json", r.Header.Get("Accept")) + + // 设置速率限制头部 + w.Header().Set("X-RateLimit-Limit", "60") + w.Header().Set("X-RateLimit-Remaining", "59") + w.Header().Set("X-RateLimit-Reset", "1735689600") + + // 设置响应头和内容 + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{ + "id": 1, + "tag_name": "v1.0.0", + "name": "Release 1.0.0", + "published_at": "2025-01-01T00:00:00Z", + "created_at": "2025-01-01T00:00:00Z", + "assets": [] + }`)) + }) + + // 创建一个自定义的 HTTP 客户端 + client := &http.Client{ + Transport: &testRoundTripper{handler: handler}, + } + + ctx := NewAPIContext("test-token", client) + repo := repository{owner: "test-owner", name: "test-repo"} + release, err := ctx.GetLatestRelease(context.Background(), repo) + + if assert.NoError(t, err) { + assert.NotNil(t, release) + assert.Equal(t, "latest(v1.0.0)", release.GetName()) + } +} + +func TestAPIContext_GetLatestRelease_NoRelease(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 验证请求路径 + assert.Equal(t, "/repos/test-owner/test-repo/releases/latest", r.URL.Path) + + // 验证请求头 + assert.Equal(t, "Bearer test-token", r.Header.Get("Authorization")) + assert.Equal(t, "application/vnd.github+json", r.Header.Get("Accept")) + + // 设置速率限制头部 + w.Header().Set("X-RateLimit-Limit", "60") + w.Header().Set("X-RateLimit-Remaining", "59") + w.Header().Set("X-RateLimit-Reset", "1735689600") + + // 返回 404 状态码 + w.WriteHeader(http.StatusNotFound) + w.Write([]byte(`{"message": "Not Found"}`)) + }) + + // 创建一个自定义的 HTTP 客户端 + client := &http.Client{ + Transport: &testRoundTripper{handler: handler}, + } + + ctx := NewAPIContext("test-token", client) + repo := repository{owner: "test-owner", name: "test-repo"} + _, err := ctx.GetLatestRelease(context.Background(), repo) + + assert.ErrorIs(t, err, ErrNoRelease) +}