Skip to content

Commit 65fcca2

Browse files
committed
Middleware interface
Signed-off-by: Vishal Rana <[email protected]>
1 parent f27de9a commit 65fcca2

File tree

8 files changed

+74
-54
lines changed

8 files changed

+74
-54
lines changed

glide.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

middleware/auth.go

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,38 +12,40 @@ type (
1212
)
1313

1414
const (
15-
Basic = "Basic"
15+
basic = "Basic"
1616
)
1717

1818
// BasicAuth returns an HTTP basic authentication middleware.
1919
//
2020
// For valid credentials it calls the next handler.
2121
// For invalid credentials, it sends "401 - Unauthorized" response.
22-
func BasicAuth(fn BasicValidateFunc) echo.HandlerFunc {
23-
return func(c echo.Context) error {
24-
// Skip WebSocket
25-
if (c.Request().Header().Get(echo.Upgrade)) == echo.WebSocket {
26-
return nil
27-
}
22+
func BasicAuth(fn BasicValidateFunc) MiddlewareFunc {
23+
return func(h echo.HandlerFunc) echo.HandlerFunc {
24+
return func(c echo.Context) error {
25+
// Skip WebSocket
26+
if (c.Request().Header().Get(echo.Upgrade)) == echo.WebSocket {
27+
return nil
28+
}
2829

29-
auth := c.Request().Header().Get(echo.Authorization)
30-
l := len(Basic)
30+
auth := c.Request().Header().Get(echo.Authorization)
31+
l := len(basic)
3132

32-
if len(auth) > l+1 && auth[:l] == Basic {
33-
b, err := base64.StdEncoding.DecodeString(auth[l+1:])
34-
if err == nil {
35-
cred := string(b)
36-
for i := 0; i < len(cred); i++ {
37-
if cred[i] == ':' {
38-
// Verify credentials
39-
if fn(cred[:i], cred[i+1:]) {
40-
return nil
33+
if len(auth) > l+1 && auth[:l] == basic {
34+
b, err := base64.StdEncoding.DecodeString(auth[l+1:])
35+
if err == nil {
36+
cred := string(b)
37+
for i := 0; i < len(cred); i++ {
38+
if cred[i] == ':' {
39+
// Verify credentials
40+
if fn(cred[:i], cred[i+1:]) {
41+
return nil
42+
}
4143
}
4244
}
4345
}
4446
}
47+
c.Response().Header().Set(echo.WWWAuthenticate, basic+" realm=Restricted")
48+
return echo.NewHTTPError(http.StatusUnauthorized)
4549
}
46-
c.Response().Header().Set(echo.WWWAuthenticate, Basic+" realm=Restricted")
47-
return echo.NewHTTPError(http.StatusUnauthorized)
4850
}
4951
}

middleware/auth_test.go

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,38 +21,41 @@ func TestBasicAuth(t *testing.T) {
2121
}
2222
return false
2323
}
24-
ba := BasicAuth(fn)
24+
h := func(c echo.Context) error {
25+
return c.String(http.StatusOK, "test")
26+
}
27+
mw := BasicAuth(fn)(h)
2528

2629
// Valid credentials
27-
auth := Basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
30+
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
2831
req.Header().Set(echo.Authorization, auth)
29-
assert.NoError(t, ba(c))
32+
assert.NoError(t, mw(c))
3033

3134
//---------------------
3235
// Invalid credentials
3336
//---------------------
3437

3538
// Incorrect password
36-
auth = Basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password"))
39+
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password"))
3740
req.Header().Set(echo.Authorization, auth)
38-
he := ba(c).(*echo.HTTPError)
41+
he := mw(c).(*echo.HTTPError)
3942
assert.Equal(t, http.StatusUnauthorized, he.Code())
40-
assert.Equal(t, Basic+" realm=Restricted", res.Header().Get(echo.WWWAuthenticate))
43+
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.WWWAuthenticate))
4144

4245
// Empty Authorization header
4346
req.Header().Set(echo.Authorization, "")
44-
he = ba(c).(*echo.HTTPError)
47+
he = mw(c).(*echo.HTTPError)
4548
assert.Equal(t, http.StatusUnauthorized, he.Code())
46-
assert.Equal(t, Basic+" realm=Restricted", res.Header().Get(echo.WWWAuthenticate))
49+
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.WWWAuthenticate))
4750

4851
// Invalid Authorization header
4952
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
5053
req.Header().Set(echo.Authorization, auth)
51-
he = ba(c).(*echo.HTTPError)
54+
he = mw(c).(*echo.HTTPError)
5255
assert.Equal(t, http.StatusUnauthorized, he.Code())
53-
assert.Equal(t, Basic+" realm=Restricted", res.Header().Get(echo.WWWAuthenticate))
56+
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.WWWAuthenticate))
5457

5558
// WebSocket
5659
c.Request().Header().Set(echo.Upgrade, echo.WebSocket)
57-
assert.NoError(t, ba(c))
60+
assert.NoError(t, mw(c))
5861
}

middleware/compress.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,10 @@ var writerPool = sync.Pool{
4848

4949
// Gzip returns a middleware which compresses HTTP response using gzip compression
5050
// scheme.
51-
func Gzip() echo.MiddlewareFunc {
52-
scheme := "gzip"
53-
51+
func Gzip() MiddlewareFunc {
5452
return func(h echo.HandlerFunc) echo.HandlerFunc {
53+
scheme := "gzip"
54+
5555
return func(c echo.Context) error {
5656
c.Response().Header().Add(echo.Vary, echo.AcceptEncoding)
5757
if strings.Contains(c.Request().Header().Get(echo.AcceptEncoding), scheme) {

middleware/logger.go renamed to middleware/log.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import (
88
"github.com/labstack/gommon/color"
99
)
1010

11-
func Logger() echo.MiddlewareFunc {
11+
func Log() MiddlewareFunc {
1212
return func(h echo.HandlerFunc) echo.HandlerFunc {
1313
return func(c echo.Context) error {
1414
req := c.Request()

middleware/logger_test.go renamed to middleware/log_test.go

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,34 +12,35 @@ import (
1212
"github.com/stretchr/testify/assert"
1313
)
1414

15-
func TestLogger(t *testing.T) {
15+
func TestLog(t *testing.T) {
1616
// Note: Just for the test coverage, not a real test.
1717
e := echo.New()
1818
req := test.NewRequest(echo.GET, "/", nil)
1919
rec := test.NewResponseRecorder()
2020
c := echo.NewContext(req, rec, e)
21-
22-
// Status 2xx
2321
h := func(c echo.Context) error {
2422
return c.String(http.StatusOK, "test")
2523
}
26-
Logger()(h)(c)
24+
mw := Log()(h)
25+
26+
// Status 2xx
27+
mw(c)
2728

2829
// Status 3xx
2930
rec = test.NewResponseRecorder()
3031
c = echo.NewContext(req, rec, e)
3132
h = func(c echo.Context) error {
3233
return c.String(http.StatusTemporaryRedirect, "test")
3334
}
34-
Logger()(h)(c)
35+
mw(c)
3536

3637
// Status 4xx
3738
rec = test.NewResponseRecorder()
3839
c = echo.NewContext(req, rec, e)
3940
h = func(c echo.Context) error {
4041
return c.String(http.StatusNotFound, "test")
4142
}
42-
Logger()(h)(c)
43+
mw(c)
4344

4445
// Status 5xx with empty path
4546
req = test.NewRequest(echo.GET, "", nil)
@@ -48,10 +49,10 @@ func TestLogger(t *testing.T) {
4849
h = func(c echo.Context) error {
4950
return errors.New("error")
5051
}
51-
Logger()(h)(c)
52+
mw(c)
5253
}
5354

54-
func TestLoggerIPAddress(t *testing.T) {
55+
func TestLogIPAddress(t *testing.T) {
5556
e := echo.New()
5657
req := test.NewRequest(echo.GET, "/", nil)
5758
rec := test.NewResponseRecorder()
@@ -62,23 +63,22 @@ func TestLoggerIPAddress(t *testing.T) {
6263
h := func(c echo.Context) error {
6364
return c.String(http.StatusOK, "test")
6465
}
65-
66-
mw := Logger()
66+
mw := Log()(h)
6767

6868
// With X-Real-IP
6969
req.Header().Add(echo.XRealIP, ip)
70-
mw(h)(c)
70+
mw(c)
7171
assert.Contains(t, buf.String(), ip)
7272

7373
// With X-Forwarded-For
7474
buf.Reset()
7575
req.Header().Del(echo.XRealIP)
7676
req.Header().Add(echo.XForwardedFor, ip)
77-
mw(h)(c)
77+
mw(c)
7878
assert.Contains(t, buf.String(), ip)
7979

8080
// with req.RemoteAddr
8181
buf.Reset()
82-
mw(h)(c)
82+
mw(c)
8383
assert.Contains(t, buf.String(), ip)
8484
}

middleware/middleware.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package middleware
2+
3+
import "github.com/labstack/echo"
4+
5+
type (
6+
Middleware interface {
7+
Process(echo.HandlerFunc) echo.HandlerFunc
8+
}
9+
10+
MiddlewareFunc func(echo.HandlerFunc) echo.HandlerFunc
11+
)
12+
13+
func (f MiddlewareFunc) Process(h echo.HandlerFunc) echo.HandlerFunc {
14+
return f(h)
15+
}

middleware/recover.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ import (
1010

1111
// Recover returns a middleware which recovers from panics anywhere in the chain
1212
// and handles the control to the centralized HTTPErrorHandler.
13-
func Recover() echo.MiddlewareFunc {
14-
// TODO: Provide better stack trace `https://github.com/go-errors/errors` `https://github.com/docker/libcontainer/tree/master/stacktrace`
13+
func Recover() MiddlewareFunc {
1514
return func(h echo.HandlerFunc) echo.HandlerFunc {
15+
// TODO: Provide better stack trace `https://github.com/go-errors/errors` `https://github.com/docker/libcontainer/tree/master/stacktrace`
1616
return func(c echo.Context) error {
1717
defer func() {
1818
if err := recover(); err != nil {

0 commit comments

Comments
 (0)