From 0f6b4e3cbfcfe71e8dc94e554c420abac79bef31 Mon Sep 17 00:00:00 2001
From: cubatic45 <148538725+cubatic45@users.noreply.github.com>
Date: Wed, 6 Mar 2024 01:18:05 +0000
Subject: [PATCH] update token when http status_code ne 200

---
 copilot.go |  6 ++++++
 proxy.go   | 14 +++++++++++---
 2 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/copilot.go b/copilot.go
index 68a4bcb..275a729 100644
--- a/copilot.go
+++ b/copilot.go
@@ -77,6 +77,12 @@ func Init() {
 // token return the access token for github copilot
 type copiloter interface {
 	token() (string, error)
+	refresh() (string, error)
+}
+
+func (c *copilot) refresh() (string, error) {
+	caches.Delete(c.GithubCom.OauthToken)
+	return c.token()
 }
 
 func (c *copilot) token() (string, error) {
diff --git a/proxy.go b/proxy.go
index 9285e35..0d90020 100644
--- a/proxy.go
+++ b/proxy.go
@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"io"
 	"net/http"
+	"os"
 	"strings"
 	"time"
 
@@ -34,10 +35,11 @@ type Proxy struct {
 }
 
 func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	write := io.MultiWriter(w, os.Stdout)
 	req, err := http.NewRequest(r.Method, r.URL.String(), r.Body)
 	if err != nil {
 		w.WriteHeader(http.StatusInternalServerError)
-		fmt.Fprintf(w, "http request error: %v", err)
+		fmt.Fprintf(write, "http request error: %v\n", err)
 	}
 	if p.Director != nil {
 		p.Director(req)
@@ -45,12 +47,15 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	resp, err := http.DefaultClient.Do(req)
 	if err != nil {
 		w.WriteHeader(http.StatusInternalServerError)
-		fmt.Fprintf(w, "http client error: %v", err)
+		fmt.Fprintf(write, "http client error: %v\n", err)
 		return
 	}
+	defer resp.Body.Close()
 	if resp.StatusCode != http.StatusOK {
 		w.WriteHeader(resp.StatusCode)
-		fmt.Fprintf(w, "http status error: %d", resp.StatusCode)
+		body, _ := io.ReadAll(resp.Body)
+		fmt.Fprintf(write, "http status error: %d\nbody: %s\n", resp.StatusCode, string(body))
+		go getCopilot().refresh()
 		return
 	}
 	if p.stream {
@@ -133,6 +138,9 @@ func handleProxy(w http.ResponseWriter, r *http.Request) {
 		req.URL.Path = strings.ReplaceAll(req.URL.Path, "/v1/", "/")
 
 		accToken, err := getCopilot().token()
+		if err != nil {
+			accToken, _ = getCopilot().refresh()
+		}
 		if accToken == "" {
 			w.WriteHeader(http.StatusInternalServerError)
 			fmt.Fprintf(w, "get acc token error: %v", err)