From babd54bcaa4518ddc3731df4fb53603a42895f09 Mon Sep 17 00:00:00 2001 From: Fergus Strange Date: Sun, 13 Jan 2019 10:15:03 +1100 Subject: [PATCH] Consolidate cookies into one method for request and response. Update all exposed cookie interfaces to use expectedCookie. --- README.md | 13 ++-- apitest.go | 79 +++++-------------- apitest_test.go | 82 +++++++++++++++++--- assert.go | 4 +- cookies.go | 144 +++++++++++++++++++++++++++++++++++ cookies_test.go | 81 ++++++++++++++++++++ examples/gin/api_test.go | 12 +++ examples/gin/server.go | 2 + examples/gorilla/api_test.go | 12 +++ examples/gorilla/server.go | 5 ++ 10 files changed, 354 insertions(+), 80 deletions(-) create mode 100644 cookies.go create mode 100644 cookies_test.go diff --git a/README.md b/README.md index 00cfffa..252f578 100644 --- a/README.md +++ b/README.md @@ -103,15 +103,12 @@ func TestApi(t *testing.T) { Patch("/hello"). Expect(t). Status(http.StatusOK). - Cookies(map[string]string{ - "ABC": "12345", - "DEF": "67890", - }). + Cookies(ExpectedCookie"ABC").Value("12345")). CookiePresent("Session-Token"). CookieNotPresent("XXX"). - HttpCookies([]http.Cookie{ - {Name: "HIJ", Value: "12345"}, - }). + Cookies( + ExpectedCookie"ABC").Value("12345"), + ExpectedCookie"DEF").Value("67890")). End() } ``` @@ -151,7 +148,7 @@ func TestApi(t *testing.T) { apitest.New(). Handler(handler). Get("/hello"). - Cookies(map[string]string{"Cookie1": "Yummy"}). + Cookies(ExpectedCookie"ABC").Value("12345")). Expect(t). Status(http.StatusOK). End() diff --git a/apitest.go b/apitest.go index 674105d..ff38a46 100644 --- a/apitest.go +++ b/apitest.go @@ -1,7 +1,6 @@ package apitest import ( - "bufio" "bytes" "encoding/json" "fmt" @@ -71,7 +70,7 @@ type Request struct { query map[string]string queryCollection map[string][]string headers map[string]string - cookies map[string]string + cookies []*expectedCookie basicAuth string apiTest *APITest } @@ -177,7 +176,7 @@ func (r *Request) Headers(h map[string]string) *Request { } // Cookies is a builder method to set the request cookies -func (r *Request) Cookies(c map[string]string) *Request { +func (r *Request) Cookies(c ...*expectedCookie) *Request { r.cookies = c return r } @@ -200,10 +199,9 @@ type Response struct { status int body string headers map[string]string - cookies map[string]string + cookies []*expectedCookie cookiesPresent []string cookiesNotPresent []string - httpCookies []http.Cookie jsonPathExpression string jsonPathAssert func(interface{}) apiTest *APITest @@ -220,17 +218,11 @@ func (r *Response) Body(b string) *Response { } // Cookies is the expected response cookies -func (r *Response) Cookies(cookies map[string]string) *Response { +func (r *Response) Cookies(cookies ...*expectedCookie) *Response { r.cookies = cookies return r } -// HttpCookies is the expected response cookies -func (r *Response) HttpCookies(cookies []http.Cookie) *Response { - r.httpCookies = cookies - return r -} - // CookiePresent is used to assert that a cookie is present in the response, // regardless of its value func (r *Response) CookiePresent(cookieName string) *Response { @@ -326,9 +318,8 @@ func (a *APITest) BuildRequest() *http.Request { req.Header.Set(k, v) } - for k, v := range a.request.cookies { - cookie := &http.Cookie{Name: k, Value: v} - req.AddCookie(cookie) + for _, cookie := range a.request.cookies { + req.AddCookie(cookie.ToHttpCookie()) } if a.request.basicAuth != "" { @@ -368,77 +359,49 @@ func (a *APITest) assertResponse(res *httptest.ResponseRecorder) { } func (a *APITest) assertCookies(response *httptest.ResponseRecorder) { - if a.response.cookies != nil { - for name, value := range a.response.cookies { + if len(a.response.cookies) > 0 { + for _, expectedCookie := range a.response.cookies { + var mismatchedFields []string foundCookie := false - for _, cookie := range getResponseCookies(response) { - if cookie.Name == name && cookie.Value == value { + for _, actualCookie := range responseCookies(response) { + cookieFound, errors := compareCookies(expectedCookie, actualCookie) + if cookieFound { foundCookie = true + mismatchedFields = append(mismatchedFields, errors...) } } - assertEqual(a.t, true, foundCookie, "Cookie not found - "+name) + assertEqual(a.t, true, foundCookie, "ExpectedCookie not found - "+*expectedCookie.name) + assertEqual(a.t, 0, len(mismatchedFields), mismatchedFields...) } } if len(a.response.cookiesPresent) > 0 { for _, cookieName := range a.response.cookiesPresent { foundCookie := false - for _, cookie := range getResponseCookies(response) { + for _, cookie := range responseCookies(response) { if cookie.Name == cookieName { foundCookie = true } } - assertEqual(a.t, true, foundCookie, "Cookie not found - "+cookieName) + assertEqual(a.t, true, foundCookie, "ExpectedCookie not found - "+cookieName) } } if len(a.response.cookiesNotPresent) > 0 { for _, cookieName := range a.response.cookiesNotPresent { foundCookie := false - for _, cookie := range getResponseCookies(response) { + for _, cookie := range responseCookies(response) { if cookie.Name == cookieName { foundCookie = true } } - assertEqual(a.t, false, foundCookie, "Cookie found - "+cookieName) + assertEqual(a.t, false, foundCookie, "ExpectedCookie found - "+cookieName) } } - - if len(a.response.httpCookies) > 0 { - for _, httpCookie := range a.response.httpCookies { - foundCookie := false - for _, cookie := range getResponseCookies(response) { - if compareHttpCookies(cookie, &httpCookie) { - foundCookie = true - } - } - assertEqual(a.t, true, foundCookie, "Cookie not found - "+httpCookie.Name) - } - } -} - -// only compare a subset of fields for flexibility -func compareHttpCookies(l *http.Cookie, r *http.Cookie) bool { - return l.Name == r.Name && - l.Value == r.Value && - l.Domain == r.Domain && - l.Expires == r.Expires && - l.MaxAge == r.MaxAge && - l.Secure == r.Secure && - l.HttpOnly == r.HttpOnly && - l.SameSite == r.SameSite } -func getResponseCookies(response *httptest.ResponseRecorder) []*http.Cookie { - for _, rawCookieString := range response.Result().Header["Set-Cookie"] { - rawRequest := fmt.Sprintf("GET / HTTP/1.0\r\nCookie: %s\r\n\r\n", rawCookieString) - req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(rawRequest))) - if err != nil { - panic("failed to parse response cookies. error: " + err.Error()) - } - return req.Cookies() - } - return []*http.Cookie{} +func responseCookies(response *httptest.ResponseRecorder) []*http.Cookie { + return response.Result().Cookies() } func (a *APITest) assertHeaders(res *httptest.ResponseRecorder) { diff --git a/apitest_test.go b/apitest_test.go index e9981fe..6667ed8 100644 --- a/apitest_test.go +++ b/apitest_test.go @@ -5,6 +5,7 @@ import ( "net/http" "reflect" "testing" + "time" ) func TestApiTest_AddsJSONBodyToRequest(t *testing.T) { @@ -141,7 +142,8 @@ func TestApiTest_AddsCookiesToRequest(t *testing.T) { Handler(handler). Method(http.MethodGet). URL("/hello"). - Cookies(map[string]string{"Cookie1": "Yummy"}). + Cookies(ExpectedCookie("Cookie1"). + Value("Yummy")). Expect(t). Status(http.StatusOK). End() @@ -214,7 +216,24 @@ func TestApiTest_MatchesTextResponseBody(t *testing.T) { func TestApiTest_MatchesResponseCookies(t *testing.T) { handler := http.NewServeMux() handler.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Set-Cookie", "ABC=12345; DEF=67890; XXX=1fsadg235; VVV=9ig32g34g") + w.Header().Set("Set-ExpectedCookie", "ABC=12345; DEF=67890; XXX=1fsadg235; VVV=9ig32g34g") + http.SetCookie(w, &http.Cookie{ + Name: "ABC", + Value: "12345", + }) + http.SetCookie(w, &http.Cookie{ + Name: "DEF", + Value: "67890", + }) + http.SetCookie(w, &http.Cookie{ + Name: "XXX", + Value: "1fsadg235", + }) + http.SetCookie(w, &http.Cookie{ + Name: "VVV", + Value: "9ig32g34g", + }) + w.WriteHeader(http.StatusOK) }) @@ -223,20 +242,27 @@ func TestApiTest_MatchesResponseCookies(t *testing.T) { Patch("/hello"). Expect(t). Status(http.StatusOK). - Cookies(map[string]string{ - "ABC": "12345", - "DEF": "67890", - }). + Cookies( + ExpectedCookie("ABC").Value("12345"), + ExpectedCookie("DEF").Value("67890")). CookiePresent("XXX"). CookiePresent("VVV"). CookieNotPresent("ZZZ"). + CookieNotPresent("TomBeer"). End() } func TestApiTest_MatchesResponseHttpCookies(t *testing.T) { handler := http.NewServeMux() handler.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Set-Cookie", "ABC=12345; DEF=67890;") + http.SetCookie(w, &http.Cookie{ + Name: "ABC", + Value: "12345", + }) + http.SetCookie(w, &http.Cookie{ + Name: "DEF", + Value: "67890", + }) w.WriteHeader(http.StatusOK) }) @@ -244,10 +270,42 @@ func TestApiTest_MatchesResponseHttpCookies(t *testing.T) { Handler(handler). Get("/hello"). Expect(t). - HttpCookies([]http.Cookie{ - {Name: "ABC", Value: "12345"}, - {Name: "DEF", Value: "67890"}, - }). + Cookies( + ExpectedCookie("ABC").Value("12345"), + ExpectedCookie("DEF").Value("67890")). + End() +} + +func TestApiTest_MatchesResponseHttpCookies_OnlySuppliedFields(t *testing.T) { + parsedDateTime, err := time.Parse(time.RFC3339, "2019-01-26T23:19:02Z") + if err != nil { + t.Fatalf("%s", err) + } + + handler := http.NewServeMux() + handler.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) { + http.SetCookie(w, &http.Cookie{ + Name: "session_id", + Value: "pdsanjdna_8e8922", + Path: "/", + Expires: parsedDateTime, + Secure: true, + HttpOnly: true, + }) + w.WriteHeader(http.StatusOK) + }) + + New(). + Handler(handler). + Get("/hello"). + Expect(t). + Cookies( + ExpectedCookie("session_id"). + Value("pdsanjdna_8e8922"). + Path("/"). + Expires(parsedDateTime). + Secure(true). + HttpOnly(true)). End() } @@ -274,7 +332,7 @@ func TestApiTest_MatchesResponseHeaders(t *testing.T) { func TestApiTest_CustomAssert(t *testing.T) { handler := http.NewServeMux() handler.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Set-Cookie", "ABC=12345; DEF=67890; XXX=1fsadg235; VVV=9ig32g34g") + w.Header().Set("Set-ExpectedCookie", "ABC=12345; DEF=67890; XXX=1fsadg235; VVV=9ig32g34g") w.WriteHeader(http.StatusOK) }) diff --git a/assert.go b/assert.go index aa36a30..76b8895 100644 --- a/assert.go +++ b/assert.go @@ -39,9 +39,9 @@ var IsServerError Assert = func(response *http.Response, request *http.Request) func assertEqual(t *testing.T, expected, actual interface{}, message ...string) { if !objectsAreEqual(expected, actual) { if len(message) > 0 { - t.Fatalf(strings.Join(message, ",")) + t.Fatalf(strings.Join(message, ", ")) } else { - t.Fatalf("Expected %s but recevied %s", expected, actual) + t.Fatalf("Expected %+v but recevied %+v", expected, actual) } } } diff --git a/cookies.go b/cookies.go new file mode 100644 index 0000000..fac6ce7 --- /dev/null +++ b/cookies.go @@ -0,0 +1,144 @@ +package apitest + +import ( + "fmt" + "net/http" + "time" +) + +type expectedCookie struct { + name *string + value *string + path *string + domain *string + expires *time.Time + maxAge *int + secure *bool + httpOnly *bool +} + +func ExpectedCookie(name string) *expectedCookie { + return &expectedCookie{ + name: &name, + } +} + +func (cookie *expectedCookie) Value(value string) *expectedCookie { + cookie.value = &value + return cookie +} + +func (cookie *expectedCookie) Path(path string) *expectedCookie { + cookie.path = &path + return cookie +} + +func (cookie *expectedCookie) Domain(domain string) *expectedCookie { + cookie.domain = &domain + return cookie +} + +func (cookie *expectedCookie) Expires(expires time.Time) *expectedCookie { + cookie.expires = &expires + return cookie +} + +func (cookie *expectedCookie) MaxAge(maxAge int) *expectedCookie { + cookie.maxAge = &maxAge + return cookie +} + +func (cookie *expectedCookie) Secure(secure bool) *expectedCookie { + cookie.secure = &secure + return cookie +} + +func (cookie *expectedCookie) HttpOnly(httpOnly bool) *expectedCookie { + cookie.httpOnly = &httpOnly + return cookie +} + +func (cookie *expectedCookie) ToHttpCookie() *http.Cookie { + httpCookie := http.Cookie{} + + if cookie.name != nil { + httpCookie.Name = *cookie.name + } + + if cookie.value != nil { + httpCookie.Value = *cookie.value + } + + if cookie.path != nil { + httpCookie.Path = *cookie.path + } + + if cookie.domain != nil { + httpCookie.Domain = *cookie.domain + } + + if cookie.expires != nil { + httpCookie.Expires = *cookie.expires + } + + if cookie.maxAge != nil { + httpCookie.MaxAge = *cookie.maxAge + } + + if cookie.secure != nil { + httpCookie.Secure = *cookie.secure + } + + if cookie.httpOnly != nil { + httpCookie.HttpOnly = *cookie.httpOnly + } + + return &httpCookie +} + +// Compares cookies based on only the provided fields from expectedCookie. +// Supported fields are Name, Value, Domain, Path, Expires, MaxAge, Secure and HttpOnly +func compareCookies(expectedCookie *expectedCookie, actualCookie *http.Cookie) (bool, []string) { + cookieFound := *expectedCookie.name == actualCookie.Name + compareErrors := []string{} + + if cookieFound { + + formatError := func(name string, expectedValue, actualValue interface{}) string { + return fmt.Sprintf("Missmatched field %s. Expected %v but received %v", + name, + expectedValue, + actualValue) + } + + if expectedCookie.value != nil && *expectedCookie.value != actualCookie.Value { + compareErrors = append(compareErrors, formatError("Value", *expectedCookie.value, actualCookie.Value)) + } + + if expectedCookie.domain != nil && *expectedCookie.domain != actualCookie.Domain { + compareErrors = append(compareErrors, formatError("Domain", *expectedCookie.value, actualCookie.Domain)) + } + + if expectedCookie.path != nil && *expectedCookie.path != actualCookie.Path { + compareErrors = append(compareErrors, formatError("Path", *expectedCookie.path, actualCookie.Path)) + } + + if expectedCookie.expires != nil && !(*expectedCookie.expires).Equal(actualCookie.Expires) { + compareErrors = append(compareErrors, formatError("Expires", *expectedCookie.expires, actualCookie.Expires)) + } + + if expectedCookie.maxAge != nil && *expectedCookie.maxAge != actualCookie.MaxAge { + compareErrors = append(compareErrors, formatError("MaxAge", *expectedCookie.maxAge, actualCookie.MaxAge)) + } + + if expectedCookie.secure != nil && *expectedCookie.secure != actualCookie.Secure { + compareErrors = append(compareErrors, formatError("Secure", *expectedCookie.secure, actualCookie.Secure)) + } + + if expectedCookie.httpOnly != nil && *expectedCookie.httpOnly != actualCookie.HttpOnly { + compareErrors = append(compareErrors, formatError("HttpOnly", *expectedCookie.httpOnly, actualCookie.HttpOnly)) + } + } + + return cookieFound, compareErrors +} diff --git a/cookies_test.go b/cookies_test.go new file mode 100644 index 0000000..9eff2b3 --- /dev/null +++ b/cookies_test.go @@ -0,0 +1,81 @@ +package apitest + +import ( + "net/http" + "testing" + "time" +) + +func TestApiTest_Cookies_ExpectedCookie(t *testing.T) { + expiry, _ := time.Parse("1/2/2006 15:04:05", "03/01/2017 12:00:00") + + cookie := ExpectedCookie("Tom"). + Value("LovesBeers"). + Path("/at-the-lyric"). + Domain("in.london"). + Expires(expiry). + MaxAge(10). + Secure(true). + HttpOnly(false) + + ten := 10 + boolt := true + boolf := false + + assertEqual(t, expectedCookie{ + name: toString("Tom"), + value: toString("LovesBeers"), + path: toString("/at-the-lyric"), + domain: toString("in.london"), + expires: &expiry, + maxAge: &ten, + secure: &boolt, + httpOnly: &boolf, + }, *cookie) +} + +func TestApiTest_Cookies_ToHttpCookie(t *testing.T) { + expiry, _ := time.Parse("1/2/2006 15:04:05", "03/01/2017 12:00:00") + + httpCookie := ExpectedCookie("Tom"). + Value("LovesBeers"). + Path("/at-the-lyric"). + Domain("in.london"). + Expires(expiry). + MaxAge(10). + Secure(true). + HttpOnly(false). + ToHttpCookie() + + assertEqual(t, http.Cookie{ + Name: "Tom", + Value: "LovesBeers", + Path: "/at-the-lyric", + Domain: "in.london", + Expires: expiry, + MaxAge: 10, + Secure: true, + HttpOnly: false, + }, *httpCookie) +} + +func TestApiTest_Cookies_ToHttpCookie_PartiallyCreated(t *testing.T) { + expiry, _ := time.Parse("1/2/2006 15:04:05", "03/01/2017 12:00:00") + + httpCookie := ExpectedCookie("Tom"). + Value("LovesBeers"). + Expires(expiry). + ToHttpCookie() + + assertEqual(t, http.Cookie{ + Name: "Tom", + Value: "LovesBeers", + Expires: expiry, + Secure: false, + HttpOnly: false, + }, *httpCookie) +} + +func toString(str string) *string { + return &str +} diff --git a/examples/gin/api_test.go b/examples/gin/api_test.go index c353f4f..b2b14e6 100644 --- a/examples/gin/api_test.go +++ b/examples/gin/api_test.go @@ -7,6 +7,18 @@ import ( "testing" ) +func TestGetUser_CookieMatching(t *testing.T) { + apitest.New(). + Handler(NewApp().Router). + Get("/user/1234"). + Expect(t). + Cookies(apitest.ExpectedCookie("TomsFavouriteDrink"). + Value("Beer"). + Path("/")). + Status(http.StatusOK). + End() +} + func TestGetUser_Success(t *testing.T) { apitest.New(). Handler(NewApp().Router). diff --git a/examples/gin/server.go b/examples/gin/server.go index ab100c4..0bbec0b 100644 --- a/examples/gin/server.go +++ b/examples/gin/server.go @@ -27,6 +27,8 @@ func (a *App) Start() { func GetUser() gin.HandlerFunc { return func(c *gin.Context) { + c.SetCookie("TomsFavouriteDrink", "Beer", 0, "/", "here.com", false, false) + id := c.Param("id") if id == "1234" { user := &User{ID: id, Name: "Andy"} diff --git a/examples/gorilla/api_test.go b/examples/gorilla/api_test.go index 3872769..7876983 100644 --- a/examples/gorilla/api_test.go +++ b/examples/gorilla/api_test.go @@ -7,6 +7,18 @@ import ( "testing" ) +func TestGetUser_CookieMatching(t *testing.T) { + apitest.New(). + Handler(newApp().Router). + Get("/user/1234"). + Expect(t). + Cookies(apitest.ExpectedCookie("TomsFavouriteDrink"). + Value("Beer"). + Path("/")). + Status(http.StatusOK). + End() +} + func TestGetUser_Success(t *testing.T) { apitest.New(). Handler(newApp().Router). diff --git a/examples/gorilla/server.go b/examples/gorilla/server.go index 9d5e7ba..7a4fba4 100644 --- a/examples/gorilla/server.go +++ b/examples/gorilla/server.go @@ -29,6 +29,11 @@ func (a *App) start() { func getUser() func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { id := mux.Vars(r)["id"] + http.SetCookie(w, &http.Cookie{ + Name: "TomsFavouriteDrink", + Value: "Beer", + Path: "/", + }) if id == "1234" { user := &User{ID: id, Name: "Andy"}