diff --git a/client.go b/client.go index 5c967ba..831274e 100644 --- a/client.go +++ b/client.go @@ -4,23 +4,57 @@ import ( "context" "net/http" "net/http/httputil" + "time" ) -// Do is called by Interceptor to complete HTTP requests. -type Do func(ctx context.Context, r *Request) (*Response, error) +// ClientOption is the functional option type. +type ClientOption func(*Client) -// InterceptorFunc provides a hook to intercept the execution of an HTTP request -// invocation. When an interceptor(s) is set, requests delegates all HTTP -// client invocations to the interceptor, and it is the responsibility of the -// interceptor to call do to complete the processing of the HTTP request. -type InterceptorFunc func(ctx context.Context, r *Request, do Do) (*Response, error) +// WithTimeout specifies a time limit for each client request. +// +// A Timeout of zero means no timeout. Default is zero. +func WithTimeout(timeout time.Duration) ClientOption { + return func(c *Client) { + c.client.Timeout = timeout + } +} +// WithTransport specifies a transport for client. +func WithTransport(transport http.RoundTripper) ClientOption { + return func(c *Client) { + c.client.Transport = transport + } +} + +// WithInterceptor specifies an interceptor for client. +// You can use [ChainInterceptors] to chain multiple interceptors into one. +func WithInterceptor(interceptor InterceptorFunc) ClientOption { + return func(c *Client) { + c.interceptor = interceptor + } +} + +// Client is an HTTP client which wraps around [http.Client] for elegant APIs and easy use. type Client struct { - *http.Client + client *http.Client + interceptor InterceptorFunc +} + +// NewClient creates a new client to serve HTTP requests. +func NewClient(setters ...ClientOption) *Client { + client := newDefaultClient() + for _, setter := range setters { + setter(client) + } + return client } -// Do sends the HTTP request and returns after response is received. -func (c *Client) Do(ctx context.Context, r *Request) (*Response, error) { +// request is the common func to send an HTTP request. +func (c *Client) request(method, url string, opts *Options, body []byte) (*Response, error) { + r, err := newRequest(method, url, opts, body) + if err != nil { + return nil, err + } if r.opts.DumpRequestOut != nil { reqDump, err := httputil.DumpRequestOut(r.Request, true) if err != nil { @@ -28,15 +62,18 @@ func (c *Client) Do(ctx context.Context, r *Request) (*Response, error) { } *r.opts.DumpRequestOut = string(reqDump) } - if ctx != nil { - r = r.WithContext(ctx) + ctx := opts.ctx + if opts.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, opts.Timeout) + defer cancel() } var interceptors []InterceptorFunc if r.opts.Interceptor != nil { interceptors = append(interceptors, r.opts.Interceptor) } - if env.interceptor != nil { - interceptors = append(interceptors, env.interceptor) + if c.interceptor != nil { + interceptors = append(interceptors, c.interceptor) } interceptor := ChainInterceptors(interceptors...) if interceptor != nil { @@ -45,10 +82,12 @@ func (c *Client) Do(ctx context.Context, r *Request) (*Response, error) { return c.do(ctx, r) } +// do sends an HTTP request and returns an HTTP response, following policy +// (such as redirects, cookies, auth) as configured on the client. func (c *Client) do(ctx context.Context, r *Request) (*Response, error) { // If the returned error is nil, the Response will contain // a non-nil Body which the user is expected to close. - resp, err := c.Client.Do(r.Request) + resp, err := c.client.Do(r.Request) if err != nil { return nil, err } @@ -62,3 +101,45 @@ func (c *Client) do(ctx context.Context, r *Request) (*Response, error) { return newResponse(resp, r.opts) } + +// Get sends an HTTP request with GET method. +// +// On error, any Response can be ignored. A non-nil Response with a +// non-nil error only occurs when Response.StatusCode() is not 2xx. +func (c *Client) Get(url string, options ...Option) (*Response, error) { + return c.callMethod(http.MethodGet, url, options...) +} + +// Post sends an HTTP POST request. +func (c *Client) Post(url string, options ...Option) (*Response, error) { + return c.callMethod(http.MethodPost, url, options...) +} + +// Put sends an HTTP request with PUT method. +// +// On error, any Response can be ignored. A non-nil Response with a +// non-nil error only occurs when Response.StatusCode() is not 2xx. +func (c *Client) Put(url string, options ...Option) (*Response, error) { + return c.callMethod(http.MethodPut, url, options...) +} + +// Patch sends an HTTP request with PATCH method. +// +// On error, any Response can be ignored. A non-nil Response with a +// non-nil error only occurs when Response.StatusCode() is not 2xx. +func (c *Client) Patch(url string, options ...Option) (*Response, error) { + return c.callMethod(http.MethodPatch, url, options...) +} + +// Delete sends an HTTP request with DELETE method. +// +// On error, any Response can be ignored. A non-nil Response with a +// non-nil error only occurs when Response.StatusCode() is not 2xx. +func (c *Client) Delete(url string, options ...Option) (*Response, error) { + return c.callMethod(http.MethodDelete, url, options...) +} + +func (c *Client) callMethod(method, url string, options ...Option) (*Response, error) { + opts := parseOptions(options...) + return dispatchers[opts.bodyType](c, method, url, opts) +} diff --git a/client_test.go b/client_test.go new file mode 100644 index 0000000..2b4637c --- /dev/null +++ b/client_test.go @@ -0,0 +1,79 @@ +package requests + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestClientOption(t *testing.T) { + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Logf("query strings: %v", r.URL.Query()) + t.Logf("headers: %v", r.Header) + time.Sleep(100 * time.Millisecond) + w.WriteHeader(http.StatusOK) + })) + defer testServer.Close() + type args struct { + url string + options []ClientOption + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "with timeout", + args: args{ + url: testServer.URL, + options: []ClientOption{ + WithTimeout(time.Millisecond), + }, + }, + wantErr: true, + }, + { + name: "with transport", + args: args{ + url: testServer.URL, + options: []ClientOption{ + WithTransport(func() http.RoundTripper { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.DisableKeepAlives = true + return transport + }()), + }, + }, + wantErr: false, + }, + { + name: "with interceptor", + args: args{ + url: testServer.URL, + options: []ClientOption{ + WithInterceptor(func(ctx context.Context, r *Request, do Do) (*Response, error) { + t.Logf("method: %s", r.Method) + return do(ctx, r) + }), + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cli := NewClient(tt.args.options...) + got, err := cli.Get(tt.args.url) + if (err != nil) != tt.wantErr { + t.Errorf("Get() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err == nil { + t.Logf("response body: %+v\n", got.Text()) + } + }) + } +} diff --git a/default.go b/default.go new file mode 100644 index 0000000..9902a93 --- /dev/null +++ b/default.go @@ -0,0 +1,38 @@ +package requests + +import ( + "net/http" + "sync" + "time" + + "github.com/Wenchy/requests/internal/auth/redirector" +) + +var ( + once sync.Once + defaultClient *Client +) + +func newDefaultClient() *Client { + return &Client{ + client: &http.Client{ + CheckRedirect: redirector.RedirectPolicyFunc, + Timeout: 10 * time.Second, + }, + } +} + +func getDefaultClient() *Client { + once.Do(func() { + defaultClient = newDefaultClient() + }) + return defaultClient +} + +// InitDefaultClient initializes the default client with given options. +func InitDefaultClient(setters ...ClientOption) { + client := getDefaultClient() + for _, setter := range setters { + setter(client) + } +} diff --git a/env.go b/env.go deleted file mode 100644 index 875e637..0000000 --- a/env.go +++ /dev/null @@ -1,84 +0,0 @@ -package requests - -import ( - "context" - "net/http" - "time" -) - -type environment struct { - timeout time.Duration - // transport establishes network connections as needed and - // caches them for reuse by subsequent calls. It uses HTTP proxies as - // directed by the environment variables HTTP_PROXY, HTTPS_PROXY and - // NO_PROXY (or the lowercase versions thereof). - transport *http.Transport - // hostRoundTrippers specify the host-specific RoundTripper to use for the - // request. If not found, http.DefaultTransport is used. - hostRoundTrippers map[string]http.RoundTripper - // interceptor intercepts each HTTP request. - interceptor InterceptorFunc -} - -var env environment - -func init() { - env.timeout = 60 * time.Second // default timeout - transport, ok := http.DefaultTransport.(*http.Transport) - if !ok { - panic("Ooh! http.DefaultTransport's underlying is not *http.Transport. Maybe golang team has changed it.") - } - env.transport = transport -} - -// SetEnvTimeout sets the default timeout for each HTTP request at -// the environment level. -func SetEnvTimeout(timeout time.Duration) { - env.timeout = timeout -} - -// WithInterceptor specifies the interceptor for each HTTP request. -func WithInterceptor(interceptors ...InterceptorFunc) { - // Prepend env.interceptor to the chaining interceptors if it exists, since - // env.interceptor will be executed before any other chained interceptors. - if env.interceptor != nil { - interceptors = append([]InterceptorFunc{env.interceptor}, interceptors...) - } - env.interceptor = ChainInterceptors(interceptors...) -} - -// ChainInterceptors chains multiple interceptors into one. -func ChainInterceptors(interceptors ...InterceptorFunc) InterceptorFunc { - switch len(interceptors) { - case 0: - return nil - case 1: - return interceptors[0] - default: - return func(ctx context.Context, r *Request, do Do) (*Response, error) { - return interceptors[0](ctx, r, getChainDo(interceptors, 0, do)) - } - } -} - -// getChainDo generates the chained do recursively. -func getChainDo(interceptors []InterceptorFunc, curr int, finalDo Do) Do { - if curr == len(interceptors)-1 { - return finalDo - } - return func(ctx context.Context, r *Request) (*Response, error) { - return interceptors[curr+1](ctx, r, getChainDo(interceptors, curr+1, finalDo)) - } -} - -// SetHostTransport sets the host-specific RoundTripper to use for the request. -// -// # Example -// -// SetHostTransport(map[string]http.RoundTripper{ -// "example1.com": http.DefaultTransport, -// "example2.com": MyCustomTransport, -// }) -func SetHostTransport(rts map[string]http.RoundTripper) { - env.hostRoundTrippers = rts -} diff --git a/interceptor.go b/interceptor.go new file mode 100644 index 0000000..7558623 --- /dev/null +++ b/interceptor.go @@ -0,0 +1,38 @@ +package requests + +import ( + "context" +) + +// Do is called by Interceptor to complete HTTP requests. +type Do func(ctx context.Context, r *Request) (*Response, error) + +// InterceptorFunc provides a hook to intercept the execution of an HTTP request +// invocation. When an interceptor(s) is set, requests delegates all HTTP +// client invocations to the interceptor, and it is the responsibility of the +// interceptor to call do to complete the processing of the HTTP request. +type InterceptorFunc func(ctx context.Context, r *Request, do Do) (*Response, error) + +// ChainInterceptors chains multiple interceptors into one. +func ChainInterceptors(interceptors ...InterceptorFunc) InterceptorFunc { + switch len(interceptors) { + case 0: + return nil + case 1: + return interceptors[0] + default: + return func(ctx context.Context, r *Request, do Do) (*Response, error) { + return interceptors[0](ctx, r, getChainDo(interceptors, 0, do)) + } + } +} + +// getChainDo generates the chained do recursively. +func getChainDo(interceptors []InterceptorFunc, curr int, finalDo Do) Do { + if curr == len(interceptors)-1 { + return finalDo + } + return func(ctx context.Context, r *Request) (*Response, error) { + return interceptors[curr+1](ctx, r, getChainDo(interceptors, curr+1, finalDo)) + } +} diff --git a/method.go b/method.go index 6076f21..8e6cc48 100644 --- a/method.go +++ b/method.go @@ -1,18 +1,16 @@ package requests -import "net/http" - // Get sends an HTTP request with GET method. // // On error, any Response can be ignored. A non-nil Response with a // non-nil error only occurs when Response.StatusCode() is not 2xx. func Get(url string, options ...Option) (*Response, error) { - return callMethod(http.MethodGet, url, options...) + return getDefaultClient().Get(url, options...) } // Post sends an HTTP POST request. func Post(url string, options ...Option) (*Response, error) { - return callMethod(http.MethodPost, url, options...) + return getDefaultClient().Post(url, options...) } // Put sends an HTTP request with PUT method. @@ -20,7 +18,7 @@ func Post(url string, options ...Option) (*Response, error) { // On error, any Response can be ignored. A non-nil Response with a // non-nil error only occurs when Response.StatusCode() is not 2xx. func Put(url string, options ...Option) (*Response, error) { - return callMethod(http.MethodPut, url, options...) + return getDefaultClient().Put(url, options...) } // Patch sends an HTTP request with PATCH method. @@ -28,7 +26,7 @@ func Put(url string, options ...Option) (*Response, error) { // On error, any Response can be ignored. A non-nil Response with a // non-nil error only occurs when Response.StatusCode() is not 2xx. func Patch(url string, options ...Option) (*Response, error) { - return callMethod(http.MethodPatch, url, options...) + return getDefaultClient().Patch(url, options...) } // Delete sends an HTTP request with DELETE method. @@ -36,5 +34,5 @@ func Patch(url string, options ...Option) (*Response, error) { // On error, any Response can be ignored. A non-nil Response with a // non-nil error only occurs when Response.StatusCode() is not 2xx. func Delete(url string, options ...Option) (*Response, error) { - return callMethod(http.MethodDelete, url, options...) + return getDefaultClient().Delete(url, options...) } diff --git a/options.go b/options.go index 870c432..6bee0b5 100644 --- a/options.go +++ b/options.go @@ -37,15 +37,12 @@ type Options struct { // request timeout Timeout time.Duration - DisableKeepAlives bool // dump DumpRequestOut *string DumpResponse *string // interceptor Interceptor InterceptorFunc - // round tripper - RoundTripper http.RoundTripper } // Option is the functional option type. @@ -54,9 +51,9 @@ type Option func(*Options) // newDefaultOptions creates a new default HTTP options. func newDefaultOptions() *Options { return &Options{ + ctx: context.Background(), Headers: http.Header{}, bodyType: bodyTypeDefault, - Timeout: env.timeout, } } @@ -334,29 +331,14 @@ func BasicAuth(username, password string) Option { } } -// Timeout specifies a time limit for requests made by this -// Client. The timeout includes connection time, any -// redirects, and reading the response body. The timer remains -// running after Get, Head, Post, or Do return and will -// interrupt reading of the Response.Body. -// -// A Timeout of zero means no timeout. Default is 60s. +// Timeout creates a new context with specified timeout for +// the current request. func Timeout(timeout time.Duration) Option { return func(opts *Options) { opts.Timeout = timeout } } -// DisableKeepAlives, if true, disables HTTP keep-alives and will -// only use the connection to the server for a single HTTP request. -// -// This is unrelated to the similarly named TCP keep-alives. -func DisableKeepAlives() Option { - return func(opts *Options) { - opts.DisableKeepAlives = true - } -} - // Dump dumps outgoing client request and response to the corresponding // input param (req or resp) if not nil. // @@ -377,12 +359,3 @@ func Interceptor(interceptor InterceptorFunc) Option { opts.Interceptor = interceptor } } - -// Transport specifies a custom RoundTripper for current request only. -// -// NOTE: If specified, then option DisableKeepAlives() will not work. -func Transport(rt http.RoundTripper) Option { - return func(opts *Options) { - opts.RoundTripper = rt - } -} diff --git a/request.go b/request.go index be35142..6ed6820 100644 --- a/request.go +++ b/request.go @@ -5,7 +5,6 @@ package requests import ( "bytes" - "context" "encoding/json" "fmt" "io" @@ -13,7 +12,6 @@ import ( "net/http" "github.com/Wenchy/requests/internal/auth" - "github.com/Wenchy/requests/internal/auth/redirector" ) // Request is a wrapper of http.Request. @@ -23,13 +21,6 @@ type Request struct { body []byte // auto filled from Request.Body } -// WithContext returns a shallow copy of r.Request with its context changed to ctx. -// The provided ctx must be non-nil. -func (r *Request) WithContext(ctx context.Context) *Request { - r.Request = r.Request.WithContext(ctx) - return r -} - // Bytes returns the HTTP request body as []byte. func (r *Request) Bytes() []byte { return r.body @@ -42,7 +33,7 @@ func (r *Request) Text() string { // newRequest creates a new HTTP request. func newRequest(method, url string, opts *Options, body []byte) (*Request, error) { - r, err := http.NewRequest(method, url, opts.Body) + r, err := http.NewRequestWithContext(opts.ctx, method, url, opts.Body) if err != nil { return nil, err } @@ -72,86 +63,8 @@ func newRequest(method, url string, opts *Options, body []byte) (*Request, error return &Request{Request: r, opts: opts, body: body}, nil } -// do sends an HTTP request and returns an HTTP response, following policy -// (such as redirects, cookies, auth) as configured on the client. -func do(method, url string, opts *Options, body []byte) (*Response, error) { - req, err := newRequest(method, url, opts, body) - if err != nil { - return nil, err - } - - // NOTE: Keep-Alive & Connection Pooling - // - // 1. Keep-Alive - // - // The net/http Transport documentation uses the term to refer to - // persistent connections. A keep-alive or persistent connection - // is a connection that can be used for more than one HTTP - // transaction. - // - // The Transport.IdleConnTimeout field specifies how long the - // transport keeps an unused connection in the pool before closing - // the connection. - // - // The net Dialer documentation uses the keep-alive term to refer - // the TCP feature for probing the health of a connection. - // Dialer.KeepAlive field specifies how frequently TCP keep-alive - // probes are sent to the peer. - // - // 2. Connection Pooling - // - // Connections are added to the pool in the function - // Transport.tryPutIdleConn. The connection is not pooled if - // Transport.DisableKeepAlives is true or Transport.MaxIdleConnsPerHost - // is less than zero. - // - // Setting either value disables pooling. The transport adds the - // "Connection: close" request header when DisableKeepAlives is true. - // This may or may not be desirable depending on what you are testing. - // - // 3. References: - // - // - https://stackoverflow.com/questions/57683132/turning-off-connection-pool-for-go-http-client - // - https://stackoverflow.com/questions/59656164/what-is-the-difference-between-net-dialerkeepalive-and-http-transportidletimeo - var roundTripper http.RoundTripper - if opts.RoundTripper != nil { - roundTripper = opts.RoundTripper - } else { - if rt := env.hostRoundTrippers[req.Host]; rt != nil { - // Use the host-specific RoundTripper if set. - roundTripper = rt - } else if opts.DisableKeepAlives { - // If option DisableKeepAlives set as true, then clone a new transport - // just for this one-off HTTP request. - transport := env.transport.Clone() - transport.DisableKeepAlives = true - roundTripper = transport - } else { - // If option DisableKeepAlives not set as true, then use the default - // transport. - roundTripper = env.transport - } - } - client := &Client{ - Client: &http.Client{ - CheckRedirect: redirector.RedirectPolicyFunc, - Timeout: opts.Timeout, - Transport: roundTripper, - }, - } - var ctx context.Context - if opts.ctx != nil { - ctx = opts.ctx // use ctx from options if set - } else { - newCtx, cancel := context.WithTimeout(context.Background(), opts.Timeout) - defer cancel() - ctx = newCtx - } - return client.Do(ctx, req) -} - // request sends an HTTP request. -func request(method, url string, opts *Options) (*Response, error) { +func request(c *Client, method, url string, opts *Options) (*Response, error) { // NOTE: get the body size from io.Reader. It is costy for large body. body := bytes.NewBuffer(nil) if opts.Body != nil { @@ -161,12 +74,12 @@ func request(method, url string, opts *Options) (*Response, error) { } } opts.Body = body - return do(method, url, opts, body.Bytes()) + return c.request(method, url, opts, body.Bytes()) } // requestData sends an HTTP request to the specified URL, with raw string // as the request body. -func requestData(method, url string, opts *Options) (*Response, error) { +func requestData(c *Client, method, url string, opts *Options) (*Response, error) { body := bytes.NewBuffer(nil) if opts.Data != nil { d := fmt.Sprintf("%v", opts.Data) @@ -178,12 +91,12 @@ func requestData(method, url string, opts *Options) (*Response, error) { // TODO: judge content type // opts.Headers["Content-Type"] = "application/x-www-form-urlencoded" opts.Body = body - return do(method, url, opts, body.Bytes()) + return c.request(method, url, opts, body.Bytes()) } // requestForm sends an HTTP request to the specified URL, with form's keys and // values URL-encoded as the request body. -func requestForm(method, url string, opts *Options) (*Response, error) { +func requestForm(c *Client, method, url string, opts *Options) (*Response, error) { body := bytes.NewBuffer(nil) if opts.Form != nil { d := opts.Form.Encode() @@ -194,11 +107,11 @@ func requestForm(method, url string, opts *Options) (*Response, error) { } opts.Headers.Set("Content-Type", "application/x-www-form-urlencoded") opts.Body = body - return do(method, url, opts, body.Bytes()) + return c.request(method, url, opts, body.Bytes()) } // requestJSON sends an HTTP request, and encode request body as json. -func requestJSON(method, url string, opts *Options) (*Response, error) { +func requestJSON(c *Client, method, url string, opts *Options) (*Response, error) { body := bytes.NewBuffer(nil) if opts.JSON != nil { d, err := json.Marshal(opts.JSON) @@ -212,11 +125,11 @@ func requestJSON(method, url string, opts *Options) (*Response, error) { } opts.Headers.Set("Content-Type", "application/json") opts.Body = body - return do(method, url, opts, body.Bytes()) + return c.request(method, url, opts, body.Bytes()) } // requestFiles sends an uploading request for multiple multipart-encoded files. -func requestFiles(method, url string, opts *Options) (*Response, error) { +func requestFiles(c *Client, method, url string, opts *Options) (*Response, error) { body := bytes.NewBuffer(nil) bodyWriter := multipart.NewWriter(body) if opts.Files != nil { @@ -236,7 +149,7 @@ func requestFiles(method, url string, opts *Options) (*Response, error) { } opts.Headers.Set("Content-Type", bodyWriter.FormDataContentType()) opts.Body = body - return do(method, url, opts, body.Bytes()) + return c.request(method, url, opts, body.Bytes()) } type bodyType int @@ -249,21 +162,12 @@ const ( bodyTypeFiles ) -type dispatcher func(method, url string, opts *Options) (*Response, error) - -var dispatchers map[bodyType]dispatcher - -func init() { - dispatchers = map[bodyType]dispatcher{ - bodyTypeDefault: request, - bodyTypeData: requestData, - bodyTypeForm: requestForm, - bodyTypeJSON: requestJSON, - bodyTypeFiles: requestFiles, - } -} +type dispatcher func(c *Client, method, url string, opts *Options) (*Response, error) -func callMethod(method, url string, options ...Option) (*Response, error) { - opts := parseOptions(options...) - return dispatchers[opts.bodyType](method, url, opts) +var dispatchers map[bodyType]dispatcher = map[bodyType]dispatcher{ + bodyTypeDefault: request, + bodyTypeData: requestData, + bodyTypeForm: requestForm, + bodyTypeJSON: requestJSON, + bodyTypeFiles: requestFiles, } diff --git a/request_test.go b/request_test.go index 9b11149..5d2d26b 100644 --- a/request_test.go +++ b/request_test.go @@ -22,7 +22,7 @@ import ( ) func init() { - WithInterceptor(logInterceptor, metricInterceptor, traceInterceptor) + InitDefaultClient(WithInterceptor(ChainInterceptors(logInterceptor, metricInterceptor, traceInterceptor))) } func logInterceptor(ctx context.Context, r *Request, do Do) (*Response, error) { @@ -96,16 +96,6 @@ func TestGet(t *testing.T) { }, wantErr: true, }, - { - name: "disable keep alive", - args: args{ - url: testServer.URL, - options: []Option{ - DisableKeepAlives(), - }, - }, - wantErr: false, - }, { name: "manipulate URLs and query parameters", args: args{ @@ -133,13 +123,13 @@ func TestGet(t *testing.T) { func TestGetWithContext(t *testing.T) { testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - time.Sleep(2 * time.Second) + time.Sleep(100 * time.Millisecond) w.WriteHeader(http.StatusOK) })) defer testServer.Close() - ctx1s, cancel := context.WithTimeout(context.Background(), time.Second) + ctx10ms, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) defer cancel() - ctx5s, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx200ms, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) defer cancel() type args struct { url string @@ -151,21 +141,21 @@ func TestGetWithContext(t *testing.T) { wantErr bool }{ { - name: "with context 1s", + name: "with context 10ms", args: args{ url: testServer.URL, options: []Option{ - Context(ctx1s), + Context(ctx10ms), }, }, wantErr: true, }, { - name: "with context 3s", + name: "with context 200ms", args: args{ url: testServer.URL, options: []Option{ - Context(ctx5s), + Context(ctx200ms), }, }, wantErr: false, @@ -446,7 +436,6 @@ func TestPostJSON(t *testing.T) { type args struct { url string options []Option - timeout time.Duration } tests := []struct { name string @@ -468,16 +457,12 @@ func TestPostJSON(t *testing.T) { ToText(&textResp), Dump(&reqDump, &respDump), }, - timeout: 5 * time.Second, }, wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if tt.args.timeout != 0 { - SetEnvTimeout(tt.args.timeout) - } got, err := Post(tt.args.url, tt.args.options...) if (err != nil) != tt.wantErr { t.Errorf("Get() error = %v, wantErr %v", err, tt.wantErr) @@ -791,100 +776,3 @@ func TestInterceptors(t *testing.T) { }) } } - -type CustomTransport struct { - *http.Transport -} - -func (ct CustomTransport) RoundTrip(req *http.Request) (*http.Response, error) { - req.Header.Set("X-Transport", "CustomTransport") - return ct.Transport.RoundTrip(req) -} - -func TestSetHostTransport(t *testing.T) { - testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - body := r.Header.Get("X-Transport") - n, err := w.Write([]byte(body)) - assert.NoError(t, err) - assert.Equal(t, n, len(body)) - })) - defer testServer.Close() - testServer1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - defer testServer1.Close() - testServer2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - body := r.Header.Get("X-Transport") + "(testServer2)" - n, err := w.Write([]byte(body)) - assert.NoError(t, err) - assert.Equal(t, n, len(body)) - })) - defer testServer.Close() - // Parse the URL string - serverURL, err := url.Parse(testServer.URL) - assert.NoError(t, err) - - defaultTransport, ok := http.DefaultTransport.(*http.Transport) - assert.Equal(t, true, ok) - - trans := defaultTransport.Clone() - trans.DisableKeepAlives = true - trans.MaxIdleConns = 1 - trans.IdleConnTimeout = 10 * time.Second - customTransport := CustomTransport{ - Transport: trans, - } - - SetHostTransport(map[string]http.RoundTripper{ - serverURL.Host: customTransport, - }) - type args struct { - url string - options []Option - } - tests := []struct { - name string - args args - wantErr bool - wantBody string - }{ - { - name: "hit host transport at env level", - args: args{ - url: testServer.URL, - }, - wantBody: "CustomTransport", - }, - { - name: "miss host transport at env level", - args: args{ - url: testServer1.URL, - }, - wantBody: "", - }, - { - name: "host transport for current request", - args: args{ - url: testServer2.URL, - options: []Option{ - Transport(customTransport), - }, - }, - wantBody: "CustomTransport(testServer2)", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := Get(tt.args.url, tt.args.options...) - if (err != nil) != tt.wantErr { - t.Errorf("Get() error = %v, wantErr %v", err, tt.wantErr) - return - } - if err == nil { - assert.Equal(t, tt.wantBody, string(got.Bytes())) - } - }) - } -}