diff --git a/.gitignore b/.gitignore index fdf6d46..111afc6 100644 --- a/.gitignore +++ b/.gitignore @@ -32,3 +32,4 @@ vendor/ # Vscode files .vscode +.history diff --git a/pkg/providers/http/do.go b/pkg/providers/http/do.go index 6c8e35d..c2757bc 100644 --- a/pkg/providers/http/do.go +++ b/pkg/providers/http/do.go @@ -21,6 +21,7 @@ import ( "crypto/tls" "crypto/x509" "encoding/base64" + "fmt" "io" "net/http" "strings" @@ -34,6 +35,7 @@ import ( wfContext "github.com/kubevela/workflow/pkg/context" "github.com/kubevela/workflow/pkg/cue/model/value" monitorContext "github.com/kubevela/workflow/pkg/monitor/context" + "github.com/kubevela/workflow/pkg/providers/http/ratelimiter" "github.com/kubevela/workflow/pkg/types" ) @@ -42,6 +44,19 @@ const ( ProviderName = "http" ) +var ( + defaultClient *http.Client + rateLimiter *ratelimiter.RateLimiter +) + +func init() { + rateLimiter = ratelimiter.NewRateLimiter(128) + defaultClient = &http.Client{ + Transport: http.DefaultTransport, + Timeout: time.Second * 3, + } +} + type provider struct { cli client.Client ns string @@ -62,17 +77,14 @@ func (h *provider) runHTTP(ctx monitorContext.Context, v *value.Value) (interfac method, u string header, trailer http.Header r io.Reader - client = &http.Client{ - Transport: http.DefaultTransport, - Timeout: time.Second * 3, - } ) + initDefaultClient(defaultClient) if timeout, err := v.GetString("request", "timeout"); err == nil && timeout != "" { duration, err := time.ParseDuration(timeout) if err != nil { return nil, err } - client.Timeout = duration + defaultClient.Timeout = duration } if method, err = v.GetString("method"); err != nil { return nil, err @@ -80,6 +92,23 @@ func (h *provider) runHTTP(ctx monitorContext.Context, v *value.Value) (interfac if u, err = v.GetString("url"); err != nil { return nil, err } + if rl, err := v.LookupValue("request", "ratelimiter"); err == nil { + limit, err := rl.GetInt64("limit") + if err != nil { + return nil, err + } + period, err := rl.GetString("period") + if err != nil { + return nil, err + } + duration, err := time.ParseDuration(period) + if err != nil { + return nil, err + } + if !rateLimiter.Allow(fmt.Sprintf("%s-%s", method, strings.Split(u, "?")[0]), int(limit), duration) { + return nil, errors.New("request exceeds the rate limiter") + } + } if body, err := v.LookupValue("request", "body"); err == nil { r, err = body.CueValue().Reader() if err != nil { @@ -105,10 +134,10 @@ func (h *provider) runHTTP(ctx monitorContext.Context, v *value.Value) (interfac req.Trailer = trailer if tr, err := h.getTransport(ctx, v); err == nil && tr != nil { - client.Transport = tr + defaultClient.Transport = tr } - resp, err := client.Do(req) + resp, err := defaultClient.Do(req) if err != nil { return nil, err } @@ -181,6 +210,11 @@ func (h *provider) getTransport(ctx monitorContext.Context, v *value.Value) (htt return tr, nil } +func initDefaultClient(c *http.Client) { + c.Transport = http.DefaultTransport + c.Timeout = time.Second * 3 +} + func parseHeaders(obj cue.Value, label string) (http.Header, error) { m := obj.LookupPath(value.FieldPath("request", label)) if !m.Exists() { diff --git a/pkg/providers/http/do_test.go b/pkg/providers/http/do_test.go index 4eaa0c7..d4b278a 100644 --- a/pkg/providers/http/do_test.go +++ b/pkg/providers/http/do_test.go @@ -38,6 +38,7 @@ import ( "github.com/kubevela/workflow/pkg/cue/model/value" monitorContext "github.com/kubevela/workflow/pkg/monitor/context" "github.com/kubevela/workflow/pkg/providers" + "github.com/kubevela/workflow/pkg/providers/http/ratelimiter" "github.com/kubevela/workflow/pkg/providers/http/testdata" ) @@ -55,6 +56,10 @@ func TestHttpDo(t *testing.T) { body?: string header?: [string]: string trailer?: [string]: string + ratelimiter?: { + limit: int + period: string + } }) response: close({ body: string @@ -150,6 +155,72 @@ request: { r.NoError(err) r.Equal(ret, tCase.expectedBody, tName) } + + // test ratelimiter + rateLimiter = ratelimiter.NewRateLimiter(1) + limiterTestCases := map[string]struct { + request string + expectedErr string + }{ + "hello": { + request: baseTemplate + ` +method: "GET" +url: "http://127.0.0.1:1229/hello" +request: { + ratelimiter: { + limit: 1 + period: "1m" + } +}`}, + "hello2": { + request: baseTemplate + ` +method: "GET" +url: "http://127.0.0.1:1229/hello?query=1" +request: { + ratelimiter: { + limit: 1 + period: "1m" + } +}`, + expectedErr: "request exceeds the rate limiter", + }, + "echo": { + request: baseTemplate + ` +method: "GET" +url: "http://127.0.0.1:1229/echo" +request: { + ratelimiter: { + limit: 1 + period: "1m" + } +}`, + }, + "hello3": { + request: baseTemplate + ` +method: "GET" +url: "http://127.0.0.1:1229/hello?query=2" +request: { + ratelimiter: { + limit: 1 + period: "1m" + } +}`, + }, + } + + for tName, tCase := range limiterTestCases { + r := require.New(t) + v, err := value.NewValue(tCase.request, nil, "") + r.NoError(err, tName) + prd := &provider{} + err = prd.Do(ctx, nil, v, nil) + if tCase.expectedErr != "" { + r.Error(err) + r.Contains(err.Error(), tCase.expectedErr) + continue + } + r.NoError(err, tName) + } } func TestInstall(t *testing.T) { diff --git a/pkg/providers/http/ratelimiter/ratelimiter.go b/pkg/providers/http/ratelimiter/ratelimiter.go new file mode 100644 index 0000000..b4f7f5e --- /dev/null +++ b/pkg/providers/http/ratelimiter/ratelimiter.go @@ -0,0 +1,49 @@ +/* +Copyright 2022 The KubeVela Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ratelimiter + +import ( + "time" + + "github.com/golang/groupcache/lru" + "golang.org/x/time/rate" +) + +// RateLimiter is the rate limiter. +type RateLimiter struct { + store *lru.Cache +} + +// NewRateLimiter returns a new rate limiter. +func NewRateLimiter(len int) *RateLimiter { + store := lru.New(len) + store.Clear() + return &RateLimiter{store: store} +} + +// Allow returns true if the operation is allowed. +func (rl *RateLimiter) Allow(id string, limit int, duration time.Duration) bool { + if l, ok := rl.store.Get(id); ok { + limiter := l.(*rate.Limiter) + if limiter.Limit() == rate.Every(duration) && limiter.Burst() == limit { + return limiter.Allow() + } + } + limiter := rate.NewLimiter(rate.Every(duration), limit) + rl.store.Add(id, limiter) + return limiter.Allow() +} diff --git a/pkg/providers/http/ratelimiter/ratelimiter_test.go b/pkg/providers/http/ratelimiter/ratelimiter_test.go new file mode 100644 index 0000000..a86630b --- /dev/null +++ b/pkg/providers/http/ratelimiter/ratelimiter_test.go @@ -0,0 +1,60 @@ +/* +Copyright 2022 The KubeVela Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ratelimiter + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestRateLimiter(t *testing.T) { + rl := NewRateLimiter(2) + r := require.New(t) + duration := time.Second + testCases := []struct { + id string + limit int + expected bool + }{ + { + id: "1", + limit: 2, + }, + { + id: "2", + limit: 2, + }, + { + id: "3", + limit: 2, + }, + { + id: "2", + limit: 3, + }, + } + for _, tc := range testCases { + for i := 0; i < tc.limit; i++ { + allow := rl.Allow(tc.id, tc.limit, duration) + r.Equal(true, allow) + } + allow := rl.Allow(tc.id, tc.limit, duration) + r.Equal(false, allow) + } +} diff --git a/pkg/stdlib/pkgs/http.cue b/pkg/stdlib/pkgs/http.cue index 16d1d53..320702a 100644 --- a/pkg/stdlib/pkgs/http.cue +++ b/pkg/stdlib/pkgs/http.cue @@ -9,8 +9,8 @@ body?: string header?: [string]: string trailer?: [string]: string - ratelimit?: { - limit: number + ratelimiter?: { + limit: int period: string } }