From 7cb449439516ed4ecf305d5e6cefddbf31b7a9a5 Mon Sep 17 00:00:00 2001 From: cubatic45 Date: Wed, 6 Mar 2024 14:20:44 +0800 Subject: [PATCH] update vscode machineid and sessionid store --- copilot.go | 41 +++++++++++++++++++++++++++++------------ header.go | 16 ++++------------ proxy.go | 2 +- 3 files changed, 34 insertions(+), 25 deletions(-) diff --git a/copilot.go b/copilot.go index 275a729..1cc5ade 100644 --- a/copilot.go +++ b/copilot.go @@ -1,6 +1,8 @@ package main import ( + "crypto/sha256" + "encoding/hex" "encoding/json" "fmt" "io" @@ -11,6 +13,7 @@ import ( "sync" "time" + "github.com/google/uuid" "github.com/mitchellh/go-homedir" "github.com/patrickmn/go-cache" "github.com/tidwall/gjson" @@ -35,6 +38,12 @@ type copilot struct { } `json:"github.com"` } +type vscodeCopilot struct { + token string + machineid string + sessionid string +} + func Init() { // init cache one := &sync.Once{} @@ -76,18 +85,18 @@ func Init() { // Copiloter is the interface that wraps the token method. // token return the access token for github copilot type copiloter interface { - token() (string, error) - refresh() (string, error) + token() (*vscodeCopilot, error) + refresh() (*vscodeCopilot, error) } -func (c *copilot) refresh() (string, error) { +func (c *copilot) refresh() (*vscodeCopilot, error) { caches.Delete(c.GithubCom.OauthToken) return c.token() } -func (c *copilot) token() (string, error) { +func (c *copilot) token() (*vscodeCopilot, error) { if cacheToken, ok := caches.Get(c.GithubCom.OauthToken); ok { - return cacheToken.(string), nil + return cacheToken.(*vscodeCopilot), nil } tokenURL := c.GithubCom.DevOverride.CopilotTokenURL if tokenURL == "" { @@ -96,27 +105,35 @@ func (c *copilot) token() (string, error) { req, err := http.NewRequest(http.MethodGet, tokenURL, nil) if err != nil { - return "", err + return nil, err } req.Header.Set("Authorization", fmt.Sprintf("token %s", c.GithubCom.OauthToken)) resp, err := http.DefaultClient.Do(req) if err != nil { - return "", err + return nil, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("get token error: %d", resp.StatusCode) + return nil, fmt.Errorf("get token error: %d", resp.StatusCode) } body, err := io.ReadAll(resp.Body) if err != nil { - return "", err + return nil, err } if token := gjson.GetBytes(body, "token").String(); token != "" { - caches.Set(c.GithubCom.OauthToken, token, 14*time.Minute) - return token, nil + sessionId := fmt.Sprintf("%s%d", uuid.New().String(), time.Now().UnixNano()/int64(time.Millisecond)) + machineID := sha256.Sum256([]byte(uuid.New().String())) + machineIDStr := hex.EncodeToString(machineID[:]) + vscodeCopilot := &vscodeCopilot{ + token: token, + machineid: machineIDStr, + sessionid: sessionId, + } + caches.Set(c.GithubCom.OauthToken, vscodeCopilot, 14*time.Minute) + return vscodeCopilot, nil } - return "", fmt.Errorf("get token error") + return nil, fmt.Errorf("get token error") } diff --git a/header.go b/header.go index eb88af4..1f84184 100644 --- a/header.go +++ b/header.go @@ -1,25 +1,17 @@ package main import ( - "crypto/sha256" - "encoding/hex" - "fmt" - "time" - "github.com/google/uuid" ) -func getAccHeaders(accessToken string) map[string]string { - sessionId := fmt.Sprintf("%s%d", uuid.New().String(), time.Now().UnixNano()/int64(time.Millisecond)) - machineID := sha256.Sum256([]byte(uuid.New().String())) - machineIDStr := hex.EncodeToString(machineID[:]) +func getAccHeaders(accessToken *vscodeCopilot) map[string]string { return map[string]string{ "Host": "api.githubcopilot.com", - "Authorization": "Bearer " + accessToken, + "Authorization": "Bearer " + accessToken.token, "X-Request-Id": uuid.New().String(), "X-Github-Api-Version": "2023-07-07", - "Vscode-Sessionid": sessionId, - "Vscode-machineid": machineIDStr, + "Vscode-Sessionid": accessToken.sessionid, + "Vscode-machineid": accessToken.machineid, "Editor-Version": "vscode/1.85.1", "Editor-Plugin-Version": "copilot-chat/0.11.1", "Openai-Organization": "github-copilot", diff --git a/proxy.go b/proxy.go index 0d90020..84a5b77 100644 --- a/proxy.go +++ b/proxy.go @@ -141,7 +141,7 @@ func handleProxy(w http.ResponseWriter, r *http.Request) { if err != nil { accToken, _ = getCopilot().refresh() } - if accToken == "" { + if accToken == nil { w.WriteHeader(http.StatusInternalServerError) fmt.Fprintf(w, "get acc token error: %v", err) return