diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..e0cb632 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,26 @@ +name: Lint + +on: + push: + branches: [master, main] + pull_request: + branches: [master, main] + +jobs: + golangci: + runs-on: ubuntu-latest + steps: + - name: Checkout Code + uses: actions/checkout@v4.2.2 + + - name: Install Go + uses: actions/setup-go@v5.4.0 + with: + go-version-file: go.mod + cache-dependency-path: go.sum + + - name: Lint + uses: golangci/golangci-lint-action@v8.0.0 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + version: v2.2.1 diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 9a103fd..36fa290 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -2,28 +2,27 @@ name: Testing on: push: - branches: [ master ] + branches: [master] pull_request: - branches: [ master ] + branches: [master] jobs: - build: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: "1.20.x" + + - name: Checkout Code + uses: actions/checkout@v4 - - name: Set up Go - uses: actions/setup-go@v2 - with: - go-version: 1.18 + - name: Test + run: go test -v -timeout 30m -race ./... -coverprofile=coverage.txt -covermode=atomic - - name: Vet - run: go vet ./... - - name: Test - run: go test -v -timeout 30m -race ./... -coverprofile=coverage.txt -covermode=atomic - - - name: Upload coverage reports to Codecov - uses: codecov/codecov-action@v3 - env: - CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v4 + with: + token: ${{ secrets.CODECOV_TOKEN }} + slug: wenchy/requests diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 0000000..ea9ceb9 --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,35 @@ +--- +version: "2" +linters: + default: none + enable: + - errcheck + - govet + # - gosec + - ineffassign + - staticcheck + - unused + - misspell + exclusions: + # presets: + # - comments + # - common-false-positives + # - legacy + # - std-error-handling + rules: + - linters: + - staticcheck + text: "QF1008:" # could remove embedded field "XXX" from selector (staticcheck) + - linters: + - staticcheck + text: "QF1007:" # could merge conditional assignment into variable declaration (staticcheck) + - linters: + - staticcheck + text: "QF1001:" # could apply De Morgan's law (staticcheck) + - linters: + - staticcheck + text: "QF1006:" # could lift into loop condition (staticcheck) +formatters: + enable: + - gofmt + - gofumpt diff --git a/env.go b/env.go index 0eb9be8..875e637 100644 --- a/env.go +++ b/env.go @@ -12,15 +12,23 @@ type environment struct { // caches them for reuse by subsequent calls. It uses HTTP proxies as // directed by the environment variables HTTP_PROXY, HTTPS_PROXY and // NO_PROXY (or the lowercase versions thereof). - transport *http.Transport + transport *http.Transport + // hostRoundTrippers specify the host-specific RoundTripper to use for the + // request. If not found, http.DefaultTransport is used. + hostRoundTrippers map[string]http.RoundTripper + // interceptor intercepts each HTTP request. interceptor InterceptorFunc } var env environment func init() { - env.timeout = 60 * time.Second // default timeout - env.transport = http.DefaultTransport.(*http.Transport).Clone() // default transport + env.timeout = 60 * time.Second // default timeout + transport, ok := http.DefaultTransport.(*http.Transport) + if !ok { + panic("Ooh! http.DefaultTransport's underlying is not *http.Transport. Maybe golang team has changed it.") + } + env.transport = transport } // SetEnvTimeout sets the default timeout for each HTTP request at @@ -53,7 +61,7 @@ func ChainInterceptors(interceptors ...InterceptorFunc) InterceptorFunc { } } -// getChainDo recursively generates the chained do. +// getChainDo generates the chained do recursively. func getChainDo(interceptors []InterceptorFunc, curr int, finalDo Do) Do { if curr == len(interceptors)-1 { return finalDo @@ -62,3 +70,15 @@ func getChainDo(interceptors []InterceptorFunc, curr int, finalDo Do) Do { return interceptors[curr+1](ctx, r, getChainDo(interceptors, curr+1, finalDo)) } } + +// SetHostTransport sets the host-specific RoundTripper to use for the request. +// +// # Example +// +// SetHostTransport(map[string]http.RoundTripper{ +// "example1.com": http.DefaultTransport, +// "example2.com": MyCustomTransport, +// }) +func SetHostTransport(rts map[string]http.RoundTripper) { + env.hostRoundTrippers = rts +} diff --git a/go.mod b/go.mod index aa4b1d3..be5cabb 100644 --- a/go.mod +++ b/go.mod @@ -1,10 +1,11 @@ module github.com/Wenchy/requests -go 1.18 +go 1.20 + +require github.com/stretchr/testify v1.9.0 require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/stretchr/testify v1.9.0 gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/options.go b/options.go index 752d2e4..870c432 100644 --- a/options.go +++ b/options.go @@ -42,13 +42,32 @@ type Options struct { DumpRequestOut *string DumpResponse *string - // custom interceptor + // interceptor Interceptor InterceptorFunc + // round tripper + RoundTripper http.RoundTripper } // Option is the functional option type. type Option func(*Options) +// newDefaultOptions creates a new default HTTP options. +func newDefaultOptions() *Options { + return &Options{ + Headers: http.Header{}, + bodyType: bodyTypeDefault, + Timeout: env.timeout, + } +} + +func parseOptions(options ...Option) *Options { + opts := newDefaultOptions() + for _, setter := range options { + setter(opts) + } + return opts +} + // Context sets the HTTP request context. // // For outgoing client request, the context controls the entire lifetime of @@ -351,27 +370,19 @@ func Dump(req, resp *string) Option { } } -// Interceptor specifies a custom interceptor, which is prepended to environment -// interceptors for current request only. +// Interceptor prepends an interceptor to environment interceptors for current +// request only. func Interceptor(interceptor InterceptorFunc) Option { return func(opts *Options) { opts.Interceptor = interceptor } } -// newDefaultOptions creates a new default HTTP options. -func newDefaultOptions() *Options { - return &Options{ - Headers: http.Header{}, - bodyType: bodyTypeDefault, - Timeout: env.timeout, - } -} - -func parseOptions(options ...Option) *Options { - opts := newDefaultOptions() - for _, setter := range options { - setter(opts) +// Transport specifies a custom RoundTripper for current request only. +// +// NOTE: If specified, then option DisableKeepAlives() will not work. +func Transport(rt http.RoundTripper) Option { + return func(opts *Options) { + opts.RoundTripper = rt } - return opts } diff --git a/request.go b/request.go index 3dcc55f..be35142 100644 --- a/request.go +++ b/request.go @@ -113,18 +113,30 @@ func do(method, url string, opts *Options, body []byte) (*Response, error) { // // - https://stackoverflow.com/questions/57683132/turning-off-connection-pool-for-go-http-client // - https://stackoverflow.com/questions/59656164/what-is-the-difference-between-net-dialerkeepalive-and-http-transportidletimeo - transport := env.transport - if opts.DisableKeepAlives { - // If option DisableKeepAlives set as true, then clone a new transport - // just for this one-off HTTP request. - transport = env.transport.Clone() - transport.DisableKeepAlives = true + var roundTripper http.RoundTripper + if opts.RoundTripper != nil { + roundTripper = opts.RoundTripper + } else { + if rt := env.hostRoundTrippers[req.Host]; rt != nil { + // Use the host-specific RoundTripper if set. + roundTripper = rt + } else if opts.DisableKeepAlives { + // If option DisableKeepAlives set as true, then clone a new transport + // just for this one-off HTTP request. + transport := env.transport.Clone() + transport.DisableKeepAlives = true + roundTripper = transport + } else { + // If option DisableKeepAlives not set as true, then use the default + // transport. + roundTripper = env.transport + } } client := &Client{ Client: &http.Client{ CheckRedirect: redirector.RedirectPolicyFunc, Timeout: opts.Timeout, - Transport: transport, + Transport: roundTripper, }, } var ctx context.Context diff --git a/request_test.go b/request_test.go index b5931c8..9b11149 100644 --- a/request_test.go +++ b/request_test.go @@ -5,7 +5,6 @@ import ( "crypto/md5" "encoding/hex" "encoding/json" - "fmt" "io" "log" "net/http" @@ -19,7 +18,7 @@ import ( "testing" "time" - "github.com/stretchr/testify/require" + "github.com/stretchr/testify/assert" ) func init() { @@ -126,7 +125,7 @@ func TestGet(t *testing.T) { return } if err == nil { - fmt.Printf("response body: %+v\n", got.Text()) + t.Logf("response body: %+v\n", got.Text()) } }) } @@ -180,7 +179,7 @@ func TestGetWithContext(t *testing.T) { return } if err == nil { - fmt.Printf("response body: %+v\n", got.Text()) + t.Logf("response body: %+v\n", got.Text()) } }) } @@ -189,12 +188,14 @@ func TestGetWithContext(t *testing.T) { func TestPostBody(t *testing.T) { testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { body, err := io.ReadAll(r.Body) - defer r.Body.Close() - if err != nil { - t.Errorf("ReadAll failed: %v", err) - } + assert.NoError(t, err) + defer func() { + assert.NoError(t, r.Body.Close()) + }() w.WriteHeader(http.StatusOK) - w.Write(body) + n, err := w.Write(body) + assert.NoError(t, err) + assert.Equal(t, n, len(body)) })) defer testServer.Close() @@ -241,12 +242,14 @@ func TestPostBody(t *testing.T) { func TestPostData(t *testing.T) { testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { body, err := io.ReadAll(r.Body) - defer r.Body.Close() - if err != nil { - t.Errorf("ReadAll failed: %v", err) - } + assert.NoError(t, err) + defer func() { + assert.NoError(t, r.Body.Close()) + }() w.WriteHeader(http.StatusOK) - w.Write(body) + n, err := w.Write(body) + assert.NoError(t, err) + assert.Equal(t, n, len(body)) })) defer testServer.Close() @@ -293,7 +296,7 @@ func TestPostData(t *testing.T) { func TestPostForm(t *testing.T) { testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { err := r.ParseForm() - require.NoError(t, err) + assert.NoError(t, err) req := testRequest{ Headers: r.Header, Params: r.URL.Query(), @@ -303,9 +306,10 @@ func TestPostForm(t *testing.T) { t.Logf("headers: %v", r.Header) w.WriteHeader(http.StatusOK) data, err := json.Marshal(req) - require.NoError(t, err) - _, err = w.Write(data) - require.NoError(t, err) + assert.NoError(t, err) + n, err := w.Write(data) + assert.NoError(t, err) + assert.Equal(t, n, len(data)) })) defer testServer.Close() type args struct { @@ -386,11 +390,11 @@ func TestPostForm(t *testing.T) { if err == nil && tt.want != nil { rsp := &testRequest{} err := got.JSON(rsp) - require.NoError(t, err) - require.Subsetf(t, rsp.Headers, tt.want.Headers, "some headers missing in HTTP server-side") - require.Subsetf(t, rsp.Params, tt.want.Params, "some params missing in HTTP server-side") - require.Subsetf(t, rsp.Form, tt.want.Form, "some form data missing in HTTP server-side") - fmt.Printf("got testRequest: %+v\n", rsp) + assert.NoError(t, err) + assert.Subsetf(t, rsp.Headers, tt.want.Headers, "some headers missing in HTTP server-side") + assert.Subsetf(t, rsp.Params, tt.want.Params, "some params missing in HTTP server-side") + assert.Subsetf(t, rsp.Form, tt.want.Form, "some form data missing in HTTP server-side") + t.Logf("got testRequest: %+v\n", rsp) } }) } @@ -415,26 +419,24 @@ func TestPostJSON(t *testing.T) { t.Logf("headers: %v", r.Header) body, err := io.ReadAll(r.Body) - if err != nil { - t.Errorf("ReadAll failed: %v", err) - } - defer r.Body.Close() + assert.NoError(t, err) + defer func() { + assert.NoError(t, r.Body.Close()) + }() var req EchoRequest - if err := json.Unmarshal(body, &req); err != nil { - t.Errorf("json unmarshal failed: %v", err) - } + err = json.Unmarshal(body, &req) + assert.NoError(t, err) jsonResp := &EchoResponse{ ID: req.ID, Name: "echo " + req.Name, } - resBytes, err := json.Marshal(jsonResp) - if err != nil { - t.Errorf("json marshal failed: %v", err) - return - } + respBytes, err := json.Marshal(jsonResp) + assert.NoError(t, err) w.WriteHeader(http.StatusOK) - w.Write(resBytes) + n, err := w.Write(respBytes) + assert.NoError(t, err) + assert.Equal(t, n, len(respBytes)) })) defer testServer.Close() @@ -505,15 +507,17 @@ func TestPostFiles(t *testing.T) { // Go 1.17: net/http: multipart form should not include directory path in filename // Refer: https://github.com/golang/go/issues/45789 file, header, err := r.FormFile(formKey) - require.NoErrorf(t, err, "get form file: %s failed", formKey) - defer file.Close() + assert.NoErrorf(t, err, "get form file: %s failed", formKey) + defer func() { + assert.NoError(t, file.Close()) + }() got, err := io.ReadAll(file) - require.NoError(t, err) + assert.NoError(t, err) path := filepath.Join("./testdata/", header.Filename) src, err := os.ReadFile(path) - require.NoError(t, err) + assert.NoError(t, err) - require.Equalf(t, string(src), string(got), "content not same: %s", formKey) + assert.Equalf(t, string(src), string(got), "content not same: %s", formKey) return nil } @@ -525,28 +529,32 @@ func TestPostFiles(t *testing.T) { if err := handleUpload("file2"); err != nil { w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte("upload form: file2 failed")) + failedMsg := "upload form: file2 failed" + n, err := w.Write([]byte(failedMsg)) + assert.NoError(t, err) + assert.Equal(t, n, len(failedMsg)) return } w.WriteHeader(http.StatusOK) - w.Write([]byte("upload file success")) + successMsg := "upload file success" + n, err := w.Write([]byte(successMsg)) + assert.NoError(t, err) + assert.Equal(t, n, len(successMsg)) })) defer testServer.Close() fh1, err := os.Open(filename1) - if err != nil { - t.Errorf("open file: %s failed: %+v", filename1, err) - return - } - defer fh1.Close() + assert.NoError(t, err) + defer func() { + assert.NoError(t, fh1.Close()) + }() fh2, err := os.Open(filename2) - if err != nil { - t.Errorf("open file: %s failed: %+v", filename2, err) - return - } - defer fh2.Close() + assert.NoError(t, err) + defer func() { + assert.NoError(t, fh2.Close()) + }() type args struct { url string @@ -604,17 +612,13 @@ func TestPostFiles(t *testing.T) { func TestPatch(t *testing.T) { testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPatch { - t.Errorf("method is not PATCH: %s", r.Method) - } - b, err := io.ReadAll(r.Body) - if err != nil { - t.Errorf("read body failed: %+v", err) - } + assert.Equalf(t, http.MethodPatch, r.Method, "method is not PATCH: %s", r.Method) + body, err := io.ReadAll(r.Body) + assert.NoError(t, err) w.WriteHeader(http.StatusOK) - if _, err := w.Write(b); err != nil { - t.Errorf("write response failed: %+v", err) - } + n, err := w.Write(body) + assert.NoError(t, err) + assert.Equal(t, n, len(body)) })) defer testServer.Close() type args struct { @@ -676,28 +680,27 @@ func TestInterceptors(t *testing.T) { filename2 := "./testdata/file2.txt" testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { body, err := io.ReadAll(r.Body) - defer r.Body.Close() - if err != nil { - t.Errorf("ReadAll failed: %v", err) - } - require.Equalf(t, strconv.Itoa(len(body)), r.Header.Get("X-Body-Size"), "content length not same") - require.Equalf(t, hex.EncodeToString(md5.New().Sum(body)), r.Header.Get("X-Body-Md5"), "content md5 not same") + assert.NoError(t, err) + defer func() { + assert.NoError(t, r.Body.Close()) + }() + + assert.Equalf(t, strconv.Itoa(len(body)), r.Header.Get("X-Body-Size"), "content length not same") + assert.Equalf(t, hex.EncodeToString(md5.New().Sum(body)), r.Header.Get("X-Body-Md5"), "content md5 not same") })) defer testServer.Close() fh1, err := os.Open(filename1) - if err != nil { - t.Errorf("open file: %s failed: %+v", filename1, err) - return - } - defer fh1.Close() + assert.NoError(t, err) + defer func() { + assert.NoError(t, fh1.Close()) + }() fh2, err := os.Open(filename2) - if err != nil { - t.Errorf("open file: %s failed: %+v", filename2, err) - return - } - defer fh2.Close() + assert.NoError(t, err) + defer func() { + assert.NoError(t, fh2.Close()) + }() type args struct { url string @@ -788,3 +791,100 @@ func TestInterceptors(t *testing.T) { }) } } + +type CustomTransport struct { + *http.Transport +} + +func (ct CustomTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req.Header.Set("X-Transport", "CustomTransport") + return ct.Transport.RoundTrip(req) +} + +func TestSetHostTransport(t *testing.T) { + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + body := r.Header.Get("X-Transport") + n, err := w.Write([]byte(body)) + assert.NoError(t, err) + assert.Equal(t, n, len(body)) + })) + defer testServer.Close() + testServer1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer testServer1.Close() + testServer2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + body := r.Header.Get("X-Transport") + "(testServer2)" + n, err := w.Write([]byte(body)) + assert.NoError(t, err) + assert.Equal(t, n, len(body)) + })) + defer testServer.Close() + // Parse the URL string + serverURL, err := url.Parse(testServer.URL) + assert.NoError(t, err) + + defaultTransport, ok := http.DefaultTransport.(*http.Transport) + assert.Equal(t, true, ok) + + trans := defaultTransport.Clone() + trans.DisableKeepAlives = true + trans.MaxIdleConns = 1 + trans.IdleConnTimeout = 10 * time.Second + customTransport := CustomTransport{ + Transport: trans, + } + + SetHostTransport(map[string]http.RoundTripper{ + serverURL.Host: customTransport, + }) + type args struct { + url string + options []Option + } + tests := []struct { + name string + args args + wantErr bool + wantBody string + }{ + { + name: "hit host transport at env level", + args: args{ + url: testServer.URL, + }, + wantBody: "CustomTransport", + }, + { + name: "miss host transport at env level", + args: args{ + url: testServer1.URL, + }, + wantBody: "", + }, + { + name: "host transport for current request", + args: args{ + url: testServer2.URL, + options: []Option{ + Transport(customTransport), + }, + }, + wantBody: "CustomTransport(testServer2)", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := Get(tt.args.url, tt.args.options...) + if (err != nil) != tt.wantErr { + t.Errorf("Get() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err == nil { + assert.Equal(t, tt.wantBody, string(got.Bytes())) + } + }) + } +} diff --git a/response.go b/response.go index 58de83c..7f95c45 100644 --- a/response.go +++ b/response.go @@ -42,14 +42,13 @@ func newResponse(resp *http.Response, opts *Options) (*Response, error) { } // readAndCloseBody drains all the HTTP response body stream and then closes it. -func (r *Response) readAndCloseBody() error { - defer r.Response.Body.Close() - var err error +func (r *Response) readAndCloseBody() (err error) { + defer func() { + err1 := r.Response.Body.Close() + err = errors.Join(err, err1) + }() r.body, err = io.ReadAll(r.Response.Body) - if err != nil { - return err - } - return nil + return err } // StatusCode returns status code of HTTP response. @@ -66,14 +65,14 @@ func (r *Response) StatusCode() int { // StatusText returns a text for the HTTP status code. // // NOTE: -// - It returns "Response is nil" if response is nil. +// - It returns "" if response is nil. // - It returns the empty string if the code is unknown. // // e.g. "OK" func (r *Response) StatusText() string { if r == nil || r.Response == nil { // return special status code -1 which is not registered with IANA. - return "Response is nil" + return "" } return r.Response.Status } @@ -111,9 +110,9 @@ func (r *Response) Headers() http.Header { // Cookies parses and returns the cookies set in the Set-Cookie headers. func (r *Response) Cookies() map[string]*http.Cookie { - m := make(map[string]*http.Cookie) + cookies := make(map[string]*http.Cookie) for _, c := range r.Response.Cookies() { - m[c.Name] = c + cookies[c.Name] = c } - return m + return cookies }