From 72186726fce2d20386e2be4de7ff109e761703c2 Mon Sep 17 00:00:00 2001 From: Moncef Belyamani Date: Tue, 31 Mar 2020 17:35:16 -0400 Subject: [PATCH] MB-2321 Allow revoking individual user sessions **Description**: In order to obtain an ATO (Authority to Operate) for MilMove, we need to provide a way to revoke individual user sessions. Currently, session management is provided by JWTs per ADR 15, but JWTs aren't designed to be revoked on an individual basis. Instead, we need to store session data on the server. In this PR, we have chosen Redis because it automatically deletes expired sessions. With Postgres, we would need to run a routine periodically to clean up stale sessions. After researching various session management solutions, I chose `scs` because it was the easiest to integrate, and it supports various stores out of the box. It is the second most popular repo after `gorilla/sessions`. I didn't pick `gorilla/sessions` because it suffers from memory leak issues, and uses its own `context` instead of the `request.Context()` provided by Golang. The maintainers are aware of the issues and have opened a GitHub issue to propose improvements for v2. However, it doesn't look like any progress has been made over the past 2 years, while `scs` has implemented most of those improvements. The name of the Redis key that holds the session data is based on the format `scs:session:token`, where `token` is the session cookie value. In order to revoke an individual session, we need to know the token corresponding to the user's session. To facilitate that lookup, I added a new `session-id` to the RequestLogger. **Setup**: `docker pull redis` **Reviewer Notes**: Things to test: **milmovelocal auth** 1. Go to milmovelocal:3000 - [ ] Verify that a session cookie named "mil_session_token" is present (Developer Tools -> Application tab -> Cookies (under Storage in the left sidebar)) - [ ] Verify that the value in the `Expires/Max-Age` column is `Session` - [ ] Verify that the HttpOnly column is checked - [ ] Verify that the Path is `/` - [ ] Verify that `SameSite` is `Lax` 2. In your Terminal, run `redis-cli`, then type `KEYS *` - [ ] Verify there is an entry labeled `scs:session:token`, where `token` is the `Value` of the `mil_session_token` cookie 3. Sign in - [ ] Verify that after successful sign in, the `Value` of the `mil_session_token` cookie changes 4. Run `KEYS *` again in the Redis console - [ ] Verify that the previous entry is gone and that a new one corresponding to the new session cookie is present - [ ] Verify that there is a `session-id` entry in the `middleware/request_logger.go` output that is the same value as the current browser cookie, not the previous one before the user signed in 5. Sign out - [ ] Verify that the previous entry in Redis is gone and that a new one corresponding to the new session cookie is present - [ ] Verify that the session cookie changed in the browser 6. In `serve.go`, on line 504, change the IdleTimeout from 15 minutes to 1 minute 7. Sign in, then wait a little over a minute 8. Try to make a new request without refreshing the browser, for example, filling out the moves form and clicking the Next button - [ ] Verify that you are not able to make a request and that you see an Unauthorized Error. Ideally, the user would be redirected to the sign in page. I'm working on implementing that. **devlocal auth** - [ ] Verify you can sign in and out via devlocal auth flow: http://milmovelocal:3000/devlocal-auth/login - [ ] Verify you can create a New milmove User - [ ] Verify you can create a New dps User **Role based auth** 1. In your `.envrc.local`, add `export FEATURE_FLAG_ROLE_BASED_AUTH=true` 2. Stop the server, run `direnv allow` 3. run `echo $FEATURE_FLAG_ROLE_BASED_AUTH` to make sure it's `true` 4. run `make server_run` 5. Go to milmovelocal:3000 and make sure you can sign in and out **References**: [gorilla sessions issues](https://github.com/gorilla/sessions/issues/105) [scs repo](https://github.com/alexedwards/scs) --- cmd/milmove/serve.go | 61 ++++- go.mod | 6 +- go.sum | 12 +- pkg/auth/authentication/auth.go | 96 +++++--- pkg/auth/authentication/auth_test.go | 185 +++++++-------- pkg/auth/authentication/authgch_test.go | 11 +- pkg/auth/authentication/authghc.go | 1 - pkg/auth/authentication/devlocal.go | 67 +++--- pkg/auth/authentication/devlocal_test.go | 67 ++++-- pkg/auth/cookie.go | 119 +--------- pkg/auth/cookie_test.go | 212 +++--------------- pkg/handlers/internalapi/api.go | 5 +- pkg/handlers/internalapi/service_members.go | 6 +- .../internalapi/service_members_test.go | 14 +- pkg/handlers/responders.go | 28 +-- pkg/middleware/request_logger.go | 6 + 16 files changed, 370 insertions(+), 526 deletions(-) diff --git a/cmd/milmove/serve.go b/cmd/milmove/serve.go index d6443f86f1d9..5bf4f2118430 100644 --- a/cmd/milmove/serve.go +++ b/cmd/milmove/serve.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "crypto/tls" + "encoding/gob" "encoding/hex" "encoding/json" "fmt" @@ -18,13 +19,17 @@ import ( "strings" "sync" "syscall" + "time" + "github.com/alexedwards/scs/redisstore" + "github.com/alexedwards/scs/v2" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials/stscreds" awssession "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/sts" "github.com/dgrijalva/jwt-go" "github.com/gobuffalo/pop" + "github.com/gomodule/redigo/redis" "github.com/gorilla/csrf" "github.com/pkg/errors" "github.com/spf13/cobra" @@ -477,10 +482,45 @@ func serveFunction(cmd *cobra.Command, args []string) error { logger.Fatal("Registering login provider", zap.Error(err)) } + gob.Register(auth.Session{}) + var sessionManager *scs.SessionManager + sessionManager = scs.New() + pool := &redis.Pool{ + MaxIdle: 10, + Dial: func() (redis.Conn, error) { + return redis.Dial("tcp", "localhost:6379") + }, + } + sessionManager.Store = redisstore.New(pool) + + // IdleTimeout controls the maximum length of time a session can be inactive + // before it expires. For example, some applications may wish to set this so + // there is a timeout after 20 minutes of inactivity. By default IdleTimeout + // is not set and there is no inactivity timeout. + // This preserves the behavior we had when we managed our own session + // cookies, where we extended the session expiry by 15 minutes on every + // request while the session was still valid. + sessionManager.IdleTimeout = 15 * time.Minute + + // Lifetime controls the maximum length of time that a session is valid for + // before it expires. The lifetime is an 'absolute expiry' which is set when + // the session is first created and does not change. The default value is 24 + // hours. + // sessionManager.Lifetime = 24 * time.Hour + + sessionManager.Cookie.Path = "/" + + // A value of false means the session cookie will be deleted when the + // browser is closed. + sessionManager.Cookie.Persist = false + useSecureCookie := !isDevOrTest + if useSecureCookie { + sessionManager.Cookie.Secure = true + } // Session management and authentication middleware noSessionTimeout := v.GetBool(cli.NoSessionTimeoutFlag) - sessionCookieMiddleware := auth.SessionCookieMiddleware(logger, clientAuthSecretKey, noSessionTimeout, appnames, useSecureCookie) + sessionCookieMiddleware := auth.SessionCookieMiddleware(logger, appnames, sessionManager) maskedCSRFMiddleware := auth.MaskedCSRFMiddleware(logger, useSecureCookie) userAuthMiddleware := authentication.UserAuthMiddleware(logger) if v.GetBool(cli.FeatureFlagRoleBasedAuth) { @@ -495,6 +535,7 @@ func serveFunction(cmd *cobra.Command, args []string) error { handlerContext.SetUseSecureCookie(useSecureCookie) if noSessionTimeout { handlerContext.SetNoSessionTimeout() + sessionManager.IdleTimeout = 24 * time.Hour } handlerContext.SetAppNames(appnames) @@ -809,7 +850,7 @@ func serveFunction(cmd *cobra.Command, args []string) error { root.Use(csrf.Protect(csrfAuthKey, csrf.Secure(!isDevOrTest), csrf.Path("/"), csrf.CookieName(auth.GorillaCSRFToken))) root.Use(maskedCSRFMiddleware) - site.Handle(pat.New("/*"), root) + site.Handle(pat.New("/*"), sessionManager.LoadAndSave(root)) if v.GetBool(cli.ServeAPIInternalFlag) { internalMux := goji.SubMux() @@ -827,7 +868,7 @@ func serveFunction(cmd *cobra.Command, args []string) error { internalMux.Handle(pat.New("/*"), internalAPIMux) internalAPIMux.Use(userAuthMiddleware) internalAPIMux.Use(middleware.NoCache(logger)) - api := internalapi.NewInternalAPI(handlerContext) + api := internalapi.NewInternalAPI(handlerContext, sessionManager) internalAPIMux.Handle(pat.New("/*"), api.Serve(nil)) if handlerContext.GetFeatureFlag(cli.FeatureFlagRoleBasedAuth) { internalAPIMux.Use(roleAuthMiddleware(api.Context())) @@ -886,18 +927,18 @@ func serveFunction(cmd *cobra.Command, args []string) error { ) authMux := goji.SubMux() root.Handle(pat.New("/auth/*"), authMux) - authMux.Handle(pat.Get("/login-gov"), authentication.RedirectHandler{Context: authContext, UseSecureCookie: useSecureCookie}) - authMux.Handle(pat.Get("/login-gov/callback"), authentication.NewCallbackHandler(authContext, dbConnection, clientAuthSecretKey, noSessionTimeout, useSecureCookie)) - authMux.Handle(pat.Post("/logout"), authentication.NewLogoutHandler(authContext, clientAuthSecretKey, noSessionTimeout, useSecureCookie)) + authMux.Handle(pat.Get("/login-gov"), authentication.NewRedirectHandler(authContext, sessionManager)) + authMux.Handle(pat.Get("/login-gov/callback"), authentication.NewCallbackHandler(authContext, dbConnection, sessionManager)) + authMux.Handle(pat.Post("/logout"), authentication.NewLogoutHandler(authContext, sessionManager)) if v.GetBool(cli.DevlocalAuthFlag) { logger.Info("Enabling devlocal auth") localAuthMux := goji.SubMux() root.Handle(pat.New("/devlocal-auth/*"), localAuthMux) - localAuthMux.Handle(pat.Get("/login"), authentication.NewUserListHandler(authContext, dbConnection, clientAuthSecretKey, noSessionTimeout, useSecureCookie)) - localAuthMux.Handle(pat.Post("/login"), authentication.NewAssignUserHandler(authContext, dbConnection, appnames, clientAuthSecretKey, noSessionTimeout, useSecureCookie)) - localAuthMux.Handle(pat.Post("/new"), authentication.NewCreateAndLoginUserHandler(authContext, dbConnection, appnames, clientAuthSecretKey, noSessionTimeout, useSecureCookie)) - localAuthMux.Handle(pat.Post("/create"), authentication.NewCreateUserHandler(authContext, dbConnection, appnames, clientAuthSecretKey, noSessionTimeout, useSecureCookie)) + localAuthMux.Handle(pat.Get("/login"), authentication.NewUserListHandler(authContext, dbConnection, sessionManager)) + localAuthMux.Handle(pat.Post("/login"), authentication.NewAssignUserHandler(authContext, dbConnection, appnames, sessionManager)) + localAuthMux.Handle(pat.Post("/new"), authentication.NewCreateAndLoginUserHandler(authContext, dbConnection, appnames, sessionManager)) + localAuthMux.Handle(pat.Post("/create"), authentication.NewCreateUserHandler(authContext, dbConnection, appnames, sessionManager)) if stringSliceContains([]string{cli.EnvironmentTest, cli.EnvironmentDevelopment}, v.GetString(cli.EnvironmentFlag)) { logger.Info("Adding devlocal CA to root CAs") diff --git a/go.mod b/go.mod index 46291eae047f..209b6b5f1c61 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,8 @@ require ( github.com/0xAX/notificator v0.0.0-20191016112426-3962a5ea8da1 // indirect github.com/99designs/aws-vault v4.5.1+incompatible github.com/99designs/keyring v1.1.4 + github.com/alexedwards/scs/redisstore v0.0.0-20200225172727-3308e1066830 + github.com/alexedwards/scs/v2 v2.3.0 github.com/aws/aws-sdk-go v1.30.7 github.com/codegangsta/envy v0.0.0-20141216192214-4b78388c8ce4 // indirect github.com/codegangsta/gin v0.0.0-20171026143024-cafe2ce98974 @@ -34,6 +36,7 @@ require ( github.com/gocarina/gocsv v0.0.0-20190927101021-3ecffd272576 github.com/gofrs/flock v0.7.1 github.com/gofrs/uuid v3.2.0+incompatible + github.com/gomodule/redigo v2.0.0+incompatible github.com/gorilla/csrf v1.6.2 github.com/imdario/mergo v0.3.9 github.com/jessevdk/go-flags v1.4.0 @@ -62,9 +65,10 @@ require ( go.mozilla.org/pkcs7 v0.0.0-20181213175627-3cffc6fbfe83 go.uber.org/zap v1.14.1 goji.io v2.0.2+incompatible - golang.org/x/crypto v0.0.0-20200128174031-69ecbb4d6d5d + golang.org/x/crypto v0.0.0-20200317142112-1b76d66859c6 golang.org/x/net v0.0.0-20200226121028-0de0cce0169b golang.org/x/text v0.3.2 + google.golang.org/appengine v1.6.5 // indirect gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect gopkg.in/go-playground/assert.v1 v1.2.1 // indirect gopkg.in/go-playground/validator.v9 v9.31.0 diff --git a/go.sum b/go.sum index 667e2d3cb545..71615f71f030 100644 --- a/go.sum +++ b/go.sum @@ -21,6 +21,10 @@ github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdko github.com/agnivade/levenshtein v1.0.1/go.mod h1:CURSv5d9Uaml+FovSIICkLbAUZ9S4RqaHDIsdSBg7lM= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/alexedwards/scs/redisstore v0.0.0-20200225172727-3308e1066830 h1:84mrg6CV1OZ8RnRZ6zkJ4bvjg/6CHXPPVDKQKecVb+0= +github.com/alexedwards/scs/redisstore v0.0.0-20200225172727-3308e1066830/go.mod h1:u2uSOc9yz8e3S+beMudSPwYL36kcbBChOLBJDAQNy38= +github.com/alexedwards/scs/v2 v2.3.0 h1:V8rtn2P5QGh8C9S7T/ikBo/AdA27vDoQJPbiAaOCmFg= +github.com/alexedwards/scs/v2 v2.3.0/go.mod h1:ToaROZxyKukJKT/xLcVQAChi5k6+Pn1Gvmdl7h3RRj8= github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883/go.mod h1:rCTlJbsFo29Kk6CurOXKm700vrz8f0KW0JNfpkRJY/8= github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= github.com/asaskevich/govalidator v0.0.0-20180720115003-f9ffefc3facf/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY= @@ -264,6 +268,8 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/gomodule/redigo v2.0.0+incompatible h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0= +github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY= @@ -596,8 +602,8 @@ golang.org/x/crypto v0.0.0-20190617133340-57b3e21c3d56/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20190621222207-cc06ce4a13d4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20200128174031-69ecbb4d6d5d h1:9FCpayM9Egr1baVnV1SX0H87m+XB0B8S0hAMi99X/3U= -golang.org/x/crypto v0.0.0-20200128174031-69ecbb4d6d5d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20200317142112-1b76d66859c6 h1:TjszyFsQsyZNHwdVdZ5m7bjmreu0znc2kRYsEml9/Ww= +golang.org/x/crypto v0.0.0-20200317142112-1b76d66859c6/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/image v0.0.0-20190823064033-3a9bac650e44/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/image v0.0.0-20190910094157-69e4b8554b2a h1:gHevYm0pO4QUbwy8Dmdr01R5r1BuKtfYqRqF0h/Cbh0= golang.org/x/image v0.0.0-20190910094157-69e4b8554b2a/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= @@ -696,6 +702,8 @@ google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9Ywl google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.6.1 h1:QzqyMA1tlu6CgqCDUtU9V+ZKhLFT2dkJuANu5QaxI3I= google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= +google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM= +google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.21.0/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= diff --git a/pkg/auth/authentication/auth.go b/pkg/auth/authentication/auth.go index 3962588ec7db..086fd8f7eb95 100644 --- a/pkg/auth/authentication/auth.go +++ b/pkg/auth/authentication/auth.go @@ -13,6 +13,7 @@ import ( "github.com/transcom/mymove/pkg/models/roles" + "github.com/alexedwards/scs/v2" "github.com/gobuffalo/pop" "github.com/gofrs/uuid" "github.com/markbates/goth" @@ -96,6 +97,7 @@ func UserAuthMiddleware(logger Logger) func(next http.Handler) http.Handler { http.Error(w, http.StatusText(401), http.StatusUnauthorized) return } + // DO NOT CHECK MILMOVE SESSION BECAUSE NEW SERVICE MEMBERS WON'T HAVE AN ID RIGHT AWAY // This must be the right type of user for the application if session.IsOfficeApp() && !session.IsOfficeUser() { @@ -258,24 +260,21 @@ func NewAuthContext(logger Logger, loginGovProvider LoginGovProvider, callbackPr // LogoutHandler handles logging the user out of login.gov type LogoutHandler struct { Context - clientAuthSecretKey string - noSessionTimeout bool - useSecureCookie bool + sessionManager *scs.SessionManager } // NewLogoutHandler creates a new LogoutHandler -func NewLogoutHandler(ac Context, clientAuthSecretKey string, noSessionTimeout bool, useSecureCookie bool) LogoutHandler { - handler := LogoutHandler{ - Context: ac, - clientAuthSecretKey: clientAuthSecretKey, - noSessionTimeout: noSessionTimeout, - useSecureCookie: useSecureCookie, +func NewLogoutHandler(ac Context, sessionManager *scs.SessionManager) LogoutHandler { + logoutHandler := LogoutHandler{ + Context: ac, + sessionManager: sessionManager, } - return handler + return logoutHandler } func (h LogoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { session := auth.SessionFromRequestContext(r) + if session != nil { redirectURL := h.landingURL(session) if session.IDToken != "" { @@ -288,11 +287,9 @@ func (h LogoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } else { logoutURL = h.loginGovProvider.LogoutURL(redirectURL, session.IDToken) } - // This operation will delete all cookies from the session - session.IDToken = "" - session.UserID = uuid.Nil - auth.WriteSessionCookie(w, session, h.clientAuthSecretKey, h.noSessionTimeout, h.logger, h.useSecureCookie) + h.sessionManager.Destroy(r.Context()) auth.DeleteCSRFCookies(w) + fmt.Fprint(w, logoutURL) } else { // Can't log out of login.gov without a token, redirect and let them re-auth @@ -309,6 +306,16 @@ const loginStateCookieTTLInSecs = 1800 // 30 mins to transit through login.gov. type RedirectHandler struct { Context UseSecureCookie bool + sessionManager *scs.SessionManager +} + +// NewRedirectHandler creates a new RedirectHandler +func NewRedirectHandler(ac Context, sessionManager *scs.SessionManager) RedirectHandler { + handler := RedirectHandler{ + Context: ac, + sessionManager: sessionManager, + } + return handler } func shaAsString(nonce string) string { @@ -361,28 +368,24 @@ func (h RedirectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // CallbackHandler processes a callback from login.gov type CallbackHandler struct { Context - db *pop.Connection - clientAuthSecretKey string - noSessionTimeout bool - useSecureCookie bool + db *pop.Connection + sessionManager *scs.SessionManager } // NewCallbackHandler creates a new CallbackHandler -func NewCallbackHandler(ac Context, db *pop.Connection, clientAuthSecretKey string, noSessionTimeout bool, useSecureCookie bool) CallbackHandler { +func NewCallbackHandler(ac Context, db *pop.Connection, sessionManager *scs.SessionManager) CallbackHandler { handler := CallbackHandler{ - Context: ac, - db: db, - clientAuthSecretKey: clientAuthSecretKey, - noSessionTimeout: noSessionTimeout, - useSecureCookie: useSecureCookie, + Context: ac, + db: db, + sessionManager: sessionManager, } return handler } // AuthorizationCallbackHandler handles the callback from the Login.gov authorization flow func (h CallbackHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - session := auth.SessionFromRequestContext(r) + if session == nil { h.logger.Error("Session missing") http.Error(w, http.StatusText(500), http.StatusInternalServerError) @@ -390,6 +393,7 @@ func (h CallbackHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } rawLandingURL := h.landingURL(session) + landingURL, err := url.Parse(rawLandingURL) if err != nil { h.logger.Error("Error parsing landing URL") @@ -431,13 +435,12 @@ func (h CallbackHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { zap.String("cookie", hash), zap.String("hash", shaAsString(returnedState))) - // This operation will delete all cookies from the session - session.IDToken = "" - session.UserID = uuid.Nil - auth.WriteSessionCookie(w, session, h.clientAuthSecretKey, h.noSessionTimeout, h.logger, h.useSecureCookie) // Delete lg_state cookie auth.DeleteCookie(w, StateCookieName(session)) + // This operation will delete all cookies from the session + h.sessionManager.Destroy(r.Context()) + // set error query landingQuery := landingURL.Query() landingQuery.Add("error", "SIGNIN_ERROR") @@ -473,6 +476,7 @@ func (h CallbackHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { session.IDToken = openIDSession.IDToken session.Email = openIDUser.Email + h.logger.Info("New Login", zap.String("OID_User", openIDUser.UserID), zap.String("OID_Email", openIDUser.Email), zap.String("Host", session.Hostname)) userIdentity, err := models.FetchUserIdentity(h.db, openIDUser.UserID) @@ -521,7 +525,13 @@ var authorizeUnknownUserNew = func(openIDUser goth.User, h CallbackHandler, sess return } } - auth.WriteSessionCookie(w, session, h.clientAuthSecretKey, h.noSessionTimeout, h.logger, h.useSecureCookie) + err = h.sessionManager.RenewToken(r.Context()) + if err != nil { + http.Error(w, err.Error(), 500) + return + } + h.sessionManager.Put(r.Context(), "session", session) + h.logger.Info("logged in", zap.Any("session", session)) http.Redirect(w, r, h.landingURL(session), http.StatusTemporaryRedirect) return } @@ -635,9 +645,15 @@ var authorizeKnownUserNew = func(userIdentity *models.UserIdentity, h CallbackHa session.LastName = userIdentity.LastName() session.Middle = userIdentity.Middle() + error := h.sessionManager.RenewToken(r.Context()) + if error != nil { + http.Error(w, error.Error(), http.StatusInternalServerError) + return + } + h.sessionManager.Put(r.Context(), "session", session) + h.logger.Info("logged in", zap.Any("session", session)) - auth.WriteSessionCookie(w, session, h.clientAuthSecretKey, h.noSessionTimeout, h.logger, h.useSecureCookie) http.Redirect(w, r, lURL, http.StatusTemporaryRedirect) } @@ -743,9 +759,17 @@ var authorizeKnownUser = func(userIdentity *models.UserIdentity, h CallbackHandl session.LastName = userIdentity.LastName() session.Middle = userIdentity.Middle() + // The session token must be renewed during sign in to prevent + // session fixation attacks + err := h.sessionManager.RenewToken(r.Context()) + if err != nil { + http.Error(w, err.Error(), 500) + return + } + h.sessionManager.Put(r.Context(), "session", session) + h.logger.Info("logged in", zap.Any("session", session)) - auth.WriteSessionCookie(w, session, h.clientAuthSecretKey, h.noSessionTimeout, h.logger, h.useSecureCookie) http.Redirect(w, r, lURL, http.StatusTemporaryRedirect) } @@ -814,9 +838,15 @@ var authorizeUnknownUser = func(openIDUser goth.User, h CallbackHandler, session return } + err = h.sessionManager.RenewToken(r.Context()) + if err != nil { + http.Error(w, err.Error(), 500) + return + } + h.sessionManager.Put(r.Context(), "session", session) + h.logger.Info("logged in", zap.Any("session", session)) - auth.WriteSessionCookie(w, session, h.clientAuthSecretKey, h.noSessionTimeout, h.logger, h.useSecureCookie) http.Redirect(w, r, h.landingURL(session), http.StatusTemporaryRedirect) } diff --git a/pkg/auth/authentication/auth_test.go b/pkg/auth/authentication/auth_test.go index 9e6cb26af7ec..59a5206bcade 100644 --- a/pkg/auth/authentication/auth_test.go +++ b/pkg/auth/authentication/auth_test.go @@ -1,6 +1,8 @@ package authentication import ( + "context" + "encoding/gob" "flag" "fmt" "log" @@ -9,7 +11,10 @@ import ( "net/url" "strconv" "testing" + "time" + "github.com/alexedwards/scs/v2" + "github.com/alexedwards/scs/v2/memstore" "github.com/markbates/goth" middleware "github.com/go-openapi/runtime/middleware" @@ -19,13 +24,11 @@ import ( "github.com/stretchr/testify/suite" "go.uber.org/zap" - "github.com/transcom/mymove/pkg/testdatagen" - "github.com/transcom/mymove/pkg/auth" "github.com/transcom/mymove/pkg/models" - "github.com/transcom/mymove/pkg/testingsuite" - "github.com/transcom/mymove/pkg/models/roles" + "github.com/transcom/mymove/pkg/testdatagen" + "github.com/transcom/mymove/pkg/testingsuite" ) const ( @@ -97,6 +100,7 @@ func (suite *AuthSuite) SetupTest() { if *useNewAuth { authorizeUnknownUser = authorizeUnknownUserNew } + gob.Register(auth.Session{}) } func TestAuthSuite(t *testing.T) { @@ -116,6 +120,23 @@ func fakeLoginGovProvider(logger Logger) LoginGovProvider { return NewLoginGovProvider("fakeHostname", "secret_key", logger) } +func setupScsSession(ctx context.Context, session *auth.Session) (context.Context, *scs.SessionManager) { + var sessionManager *scs.SessionManager + sessionManager = scs.New() + store := memstore.New() + sessionManager.Store = store + + values := make(map[string]interface{}) + values["session"] = session + expiry := time.Now().Add(30 * time.Minute).UTC() + b, _ := sessionManager.Codec.Encode(expiry, values) + + store.Commit("session_token", b, expiry) + scsContext, _ := sessionManager.Load(ctx, "session_token") + sessionManager.Commit(scsContext) + return scsContext, sessionManager +} + func (suite *AuthSuite) TestGenerateNonce() { t := suite.T() nonce := generateNonce() @@ -141,9 +162,11 @@ func (suite *AuthSuite) TestAuthorizationLogoutHandler() { } ctx := auth.SetSessionInRequestContext(req, &session) req = req.WithContext(ctx) + var sessionManager *scs.SessionManager + sessionManager = scs.New() authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) - handler := LogoutHandler{authContext, "fake key", false, false} + handler := sessionManager.LoadAndSave(LogoutHandler{authContext, sessionManager}) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req.WithContext(ctx)) @@ -188,7 +211,9 @@ func (suite *AuthSuite) TestRequireAuthMiddleware() { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handlerSession = auth.SessionFromRequestContext(r) }) - middleware := UserAuthMiddleware(suite.logger)(handler) + var sessionManager *scs.SessionManager + sessionManager = scs.New() + middleware := sessionManager.LoadAndSave(UserAuthMiddleware(suite.logger)(handler)) middleware.ServeHTTP(rr, req) @@ -201,7 +226,9 @@ func (suite *AuthSuite) TestIsLoggedInWhenNoUserLoggedIn() { req := httptest.NewRequest("GET", "/is_logged_in", nil) rr := httptest.NewRecorder() - handler := http.HandlerFunc(IsLoggedInMiddleware(suite.logger)) + var sessionManager *scs.SessionManager + sessionManager = scs.New() + handler := sessionManager.LoadAndSave(IsLoggedInMiddleware(suite.logger)) handler.ServeHTTP(rr, req) @@ -224,13 +251,15 @@ func (suite *AuthSuite) TestIsLoggedInWhenUserLoggedIn() { req := httptest.NewRequest("GET", "/is_logged_in", nil) + var sessionManager *scs.SessionManager + sessionManager = scs.New() // And: the context contains the auth values session := auth.Session{UserID: user.ID, IDToken: "fake Token"} ctx := auth.SetSessionInRequestContext(req, &session) req = req.WithContext(ctx) rr := httptest.NewRecorder() - handler := IsLoggedInMiddleware(suite.logger) + handler := sessionManager.LoadAndSave(IsLoggedInMiddleware(suite.logger)) handler.ServeHTTP(rr, req) @@ -250,7 +279,9 @@ func (suite *AuthSuite) TestRequireAuthMiddlewareUnauthorized() { req := httptest.NewRequest("GET", "/moves", nil) handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) - middleware := UserAuthMiddleware(suite.logger)(handler) + var sessionManager *scs.SessionManager + sessionManager = scs.New() + middleware := sessionManager.LoadAndSave(UserAuthMiddleware(suite.logger)(handler)) middleware.ServeHTTP(rr, req) @@ -329,12 +360,12 @@ func (suite *AuthSuite) TestAuthorizeDeactivateUser() { ctx := auth.SetSessionInRequestContext(req, &session) callbackPort := 1234 authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) + var sessionManager *scs.SessionManager + sessionManager = scs.New() h := CallbackHandler{ authContext, suite.DB(), - "fake key", - false, - false, + sessionManager, } rr := httptest.NewRecorder() authorizeKnownUser(&userIdentity, h, &session, rr, req.WithContext(ctx), "") @@ -362,15 +393,16 @@ func (suite *AuthSuite) TestAuthKnownSingleRoleOffice() { ctx := auth.SetSessionInRequestContext(req, &session) callbackPort := 1234 authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) + + scsContext, sessionManager := setupScsSession(ctx, &session) + h := CallbackHandler{ authContext, suite.DB(), - "fake key", - false, - false, + sessionManager, } rr := httptest.NewRecorder() - authorizeKnownUser(&userIdentity, h, &session, rr, req.WithContext(ctx), "") + authorizeKnownUser(&userIdentity, h, &session, rr, req.WithContext(scsContext), "") // Office app, so should only have office ID information suite.Equal(officeUserID, session.OfficeUserID) @@ -396,12 +428,12 @@ func (suite *AuthSuite) TestAuthorizeDeactivateOfficeUser() { ctx := auth.SetSessionInRequestContext(req, &session) callbackPort := 1234 authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) + var sessionManager *scs.SessionManager + sessionManager = scs.New() h := CallbackHandler{ authContext, suite.DB(), - "fake key", - false, - false, + sessionManager, } rr := httptest.NewRecorder() authorizeKnownUser(&userIdentity, h, &session, rr, req.WithContext(ctx), "") @@ -409,64 +441,6 @@ func (suite *AuthSuite) TestAuthorizeDeactivateOfficeUser() { suite.Equal(http.StatusForbidden, rr.Code, "authorizer did not recognize deactivated office user") } -func (suite *AuthSuite) TestRedirectLoginGovErrorMsg() { - officeUserID := uuid.Must(uuid.NewV4()) - userIdentity := models.UserIdentity{ - Active: true, - OfficeUserID: &officeUserID, - } - - req := httptest.NewRequest("GET", fmt.Sprintf("http://%s/login-gov/callback", OfficeTestHost), nil) - - fakeToken := "some_token" - fakeUUID, _ := uuid.FromString("39b28c92-0506-4bef-8b57-e39519f42dc2") - session := auth.Session{ - ApplicationName: auth.OfficeApp, - UserID: fakeUUID, - IDToken: fakeToken, - Hostname: OfficeTestHost, - } - // login.gov state cookie - cookieName := StateCookieName(&session) - cookie := http.Cookie{ - Name: cookieName, - Value: "some mis-matched hash value", - Path: "/", - Expires: auth.GetExpiryTimeFromMinutes(auth.SessionExpiryInMinutes), - } - req.AddCookie(&cookie) - - ctx := auth.SetSessionInRequestContext(req, &session) - callbackPort := 1234 - authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) - h := CallbackHandler{ - authContext, - suite.DB(), - "fake key", - false, - false, - } - rr := httptest.NewRecorder() - authorizeKnownUser(&userIdentity, h, &session, rr, req.WithContext(ctx), "") - - rr2 := httptest.NewRecorder() - h.ServeHTTP(rr2, req.WithContext(ctx)) - - // Office app, so should only have office ID information - suite.Equal(officeUserID, session.OfficeUserID) - - suite.Equal(2, len(rr2.Result().Cookies())) - // check for blank value for cookie login gov state value and the session cookie value - for _, cookie := range rr2.Result().Cookies() { - if cookie.Name == cookieName || cookie.Name == fmt.Sprintf("%s_%s", string(session.ApplicationName), auth.UserSessionCookieName) { - suite.Equal("blank", cookie.Value) - suite.Equal("/", cookie.Path) - } - } - - suite.Equal("http://office.example.com:1234/?error=SIGNIN_ERROR", rr2.Result().Header.Get("Location")) -} - func (suite *AuthSuite) TestAuthKnownSingleRoleAdmin() { adminUserID := uuid.Must(uuid.NewV4()) officeUserID := uuid.Must(uuid.NewV4()) @@ -493,15 +467,16 @@ func (suite *AuthSuite) TestAuthKnownSingleRoleAdmin() { ctx := auth.SetSessionInRequestContext(req, &session) callbackPort := 1234 authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) + + scsContext, sessionManager := setupScsSession(ctx, &session) + h := CallbackHandler{ authContext, suite.DB(), - "fake key", - false, - false, + sessionManager, } rr := httptest.NewRecorder() - authorizeKnownUser(&userIdentity, h, &session, rr, req.WithContext(ctx), "") + authorizeKnownUser(&userIdentity, h, &session, rr, req.WithContext(scsContext), "") // admin app, so should only have admin ID information suite.Equal(adminUserID, session.AdminUserID) @@ -531,12 +506,12 @@ func (suite *AuthSuite) TestAuthorizeDeactivateAdmin() { ctx := auth.SetSessionInRequestContext(req, &session) callbackPort := 1234 authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) + var sessionManager *scs.SessionManager + sessionManager = scs.New() h := CallbackHandler{ authContext, suite.DB(), - "fake key", - false, - false, + sessionManager, } rr := httptest.NewRecorder() authorizeKnownUser(&userIdentity, h, &session, rr, req.WithContext(ctx), "") @@ -568,12 +543,12 @@ func (suite *AuthSuite) TestAuthorizeUnknownUserOfficeDeactivated() { callbackPort := 1234 authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) + var sessionManager *scs.SessionManager + sessionManager = scs.New() h := CallbackHandler{ authContext, suite.DB(), - "fake key", - false, - false, + sessionManager, } rr := httptest.NewRecorder() @@ -604,12 +579,12 @@ func (suite *AuthSuite) TestAuthorizeUnknownUserOfficeNotFound() { callbackPort := 1234 authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) + var sessionManager *scs.SessionManager + sessionManager = scs.New() h := CallbackHandler{ authContext, suite.DB(), - "fake key", - false, - false, + sessionManager, } rr := httptest.NewRecorder() @@ -651,16 +626,17 @@ func (suite *AuthSuite) TestAuthorizeUnknownUserOfficeLogsIn() { callbackPort := 1234 authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) + + scsContext, sessionManager := setupScsSession(ctx, &session) + h := CallbackHandler{ authContext, suite.DB(), - "fake key", - false, - false, + sessionManager, } rr := httptest.NewRecorder() - authorizeUnknownUser(user, h, &session, rr, req.WithContext(ctx), "") + authorizeUnknownUser(user, h, &session, rr, req.WithContext(scsContext), "") // Office app, so should only have office ID information suite.Equal(officeUser.ID, session.OfficeUserID) @@ -687,12 +663,12 @@ func (suite *AuthSuite) TestAuthorizeUnknownUserAdminDeactivated() { callbackPort := 1234 authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) + var sessionManager *scs.SessionManager + sessionManager = scs.New() h := CallbackHandler{ authContext, suite.DB(), - "fake key", - false, - false, + sessionManager, } rr := httptest.NewRecorder() @@ -723,12 +699,12 @@ func (suite *AuthSuite) TestAuthorizeUnknownUserAdminNotFound() { callbackPort := 1234 authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) + var sessionManager *scs.SessionManager + sessionManager = scs.New() h := CallbackHandler{ authContext, suite.DB(), - "fake key", - false, - false, + sessionManager, } rr := httptest.NewRecorder() @@ -764,16 +740,17 @@ func (suite *AuthSuite) TestAuthorizeUnknownUserAdminLogsIn() { callbackPort := 1234 authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) + + scsContext, sessionManager := setupScsSession(ctx, &session) + h := CallbackHandler{ authContext, suite.DB(), - FakeRSAKey, - false, - false, + sessionManager, } rr := httptest.NewRecorder() - authorizeUnknownUser(user, h, &session, rr, req.WithContext(ctx), "") + authorizeUnknownUser(user, h, &session, rr, req.WithContext(scsContext), "") // Office app, so should only have office ID information suite.Equal(adminUser.ID, session.AdminUserID) diff --git a/pkg/auth/authentication/authgch_test.go b/pkg/auth/authentication/authgch_test.go index 248793efd17a..b77e30d96f39 100644 --- a/pkg/auth/authentication/authgch_test.go +++ b/pkg/auth/authentication/authgch_test.go @@ -5,6 +5,8 @@ import ( "net/http" "net/http/httptest" + "github.com/alexedwards/scs/v2" + "github.com/transcom/mymove/pkg/models/roles" "github.com/transcom/mymove/pkg/cli" @@ -49,17 +51,16 @@ func (suite *AuthSuite) TestCreateTOO() { req.AddCookie(&cookie) callbackPort := 1234 authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) + var sessionManager *scs.SessionManager + sessionManager = scs.New() h := CallbackHandler{ authContext, suite.DB(), - FakeRSAKey, - false, - false, + sessionManager, } rr := httptest.NewRecorder() h.SetFeatureFlag(FeatureFlag{Name: cli.FeatureFlagRoleBasedAuth, Active: true}) - h.ServeHTTP(rr, req) - + sessionManager.LoadAndSave(h).ServeHTTP(rr, req) suite.Equal(rr.Code, 307) } diff --git a/pkg/auth/authentication/authghc.go b/pkg/auth/authentication/authghc.go index 79a397005544..c0d076a7c144 100644 --- a/pkg/auth/authentication/authghc.go +++ b/pkg/auth/authentication/authghc.go @@ -137,7 +137,6 @@ func (uua UnknownUserAuthorizer) AuthorizeUnknownUser(openIDUser goth.User, sess return err } } - uua.logger.Info("logged in", zap.Any("session", session)) return nil } diff --git a/pkg/auth/authentication/devlocal.go b/pkg/auth/authentication/devlocal.go index 497504a9c936..8f101c60a162 100644 --- a/pkg/auth/authentication/devlocal.go +++ b/pkg/auth/authentication/devlocal.go @@ -8,6 +8,7 @@ import ( "strings" "time" + "github.com/alexedwards/scs/v2" "github.com/gobuffalo/pop" "github.com/gofrs/uuid" "github.com/gorilla/csrf" @@ -33,19 +34,15 @@ const ( type UserListHandler struct { db *pop.Connection Context - clientAuthSecretKey string - noSessionTimeout bool - useSecureCookie bool + sessionManager *scs.SessionManager } // NewUserListHandler returns a new UserListHandler -func NewUserListHandler(ac Context, db *pop.Connection, clientAuthSecretKey string, noSessionTimeout bool, useSecureCookie bool) UserListHandler { +func NewUserListHandler(ac Context, db *pop.Connection, sessionManager *scs.SessionManager) UserListHandler { handler := UserListHandler{ - Context: ac, - db: db, - clientAuthSecretKey: clientAuthSecretKey, - noSessionTimeout: noSessionTimeout, - useSecureCookie: useSecureCookie, + Context: ac, + db: db, + sessionManager: sessionManager, } return handler } @@ -57,9 +54,7 @@ func (h UserListHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // User is already authenticated, so clear out their current session and have // them try again. This the issue where a developer will get stuck with a stale // session and have to manually clear cookies to get back to the login page. - session.IDToken = "" - session.UserID = uuid.Nil - auth.WriteSessionCookie(w, session, h.clientAuthSecretKey, h.noSessionTimeout, h.logger, h.useSecureCookie) + h.sessionManager.Destroy(r.Context()) auth.DeleteCSRFCookies(w) http.Redirect(w, r, h.landingURL(session), http.StatusTemporaryRedirect) @@ -199,25 +194,21 @@ func (h UserListHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { type devlocalAuthHandler struct { Context - db *pop.Connection - appnames auth.ApplicationServername - clientAuthSecretKey string - noSessionTimeout bool - useSecureCookie bool + db *pop.Connection + appnames auth.ApplicationServername + sessionManager *scs.SessionManager } // AssignUserHandler logs a user in directly type AssignUserHandler devlocalAuthHandler // NewAssignUserHandler creates a new AssignUserHandler -func NewAssignUserHandler(ac Context, db *pop.Connection, appnames auth.ApplicationServername, clientAuthSecretKey string, noSessionTimeout bool, useSecureCookie bool) AssignUserHandler { +func NewAssignUserHandler(ac Context, db *pop.Connection, appnames auth.ApplicationServername, sessionManager *scs.SessionManager) AssignUserHandler { handler := AssignUserHandler{ - Context: ac, - db: db, - appnames: appnames, - clientAuthSecretKey: clientAuthSecretKey, - noSessionTimeout: noSessionTimeout, - useSecureCookie: useSecureCookie, + Context: ac, + db: db, + appnames: appnames, + sessionManager: sessionManager, } return handler } @@ -269,14 +260,12 @@ func (h AssignUserHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { type CreateUserHandler devlocalAuthHandler // NewCreateUserHandler creates a new CreateUserHandler -func NewCreateUserHandler(ac Context, db *pop.Connection, appnames auth.ApplicationServername, clientAuthSecretKey string, noSessionTimeout bool, useSecureCookie bool) CreateUserHandler { +func NewCreateUserHandler(ac Context, db *pop.Connection, appnames auth.ApplicationServername, sessionManager *scs.SessionManager) CreateUserHandler { handler := CreateUserHandler{ - Context: ac, - db: db, - appnames: appnames, - clientAuthSecretKey: clientAuthSecretKey, - noSessionTimeout: noSessionTimeout, - useSecureCookie: useSecureCookie, + Context: ac, + db: db, + appnames: appnames, + sessionManager: sessionManager, } return handler } @@ -302,14 +291,12 @@ func (h CreateUserHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { type CreateAndLoginUserHandler devlocalAuthHandler // NewCreateAndLoginUserHandler creates a new CreateAndLoginUserHandler -func NewCreateAndLoginUserHandler(ac Context, db *pop.Connection, appnames auth.ApplicationServername, clientAuthSecretKey string, noSessionTimeout bool, useSecureCookie bool) CreateAndLoginUserHandler { +func NewCreateAndLoginUserHandler(ac Context, db *pop.Connection, appnames auth.ApplicationServername, sessionManager *scs.SessionManager) CreateAndLoginUserHandler { handler := CreateAndLoginUserHandler{ - Context: ac, - db: db, - appnames: appnames, - clientAuthSecretKey: clientAuthSecretKey, - noSessionTimeout: noSessionTimeout, - useSecureCookie: useSecureCookie, + Context: ac, + db: db, + appnames: appnames, + sessionManager: sessionManager, } return handler } @@ -535,10 +522,10 @@ func createSession(h devlocalAuthHandler, user *models.User, userType string, w session.LastName = userIdentity.LastName() session.Middle = userIdentity.Middle() + h.sessionManager.Cookie.Name = auth.SessionCookieName(session) + h.sessionManager.Put(r.Context(), "session", session) // Writing out the session cookie logs in the user h.logger.Info("logged in", zap.Any("session", session)) - auth.WriteSessionCookie(w, session, h.clientAuthSecretKey, h.noSessionTimeout, h.logger, h.useSecureCookie) - return session, nil } diff --git a/pkg/auth/authentication/devlocal_test.go b/pkg/auth/authentication/devlocal_test.go index e522c51984ec..865dbfd05f88 100644 --- a/pkg/auth/authentication/devlocal_test.go +++ b/pkg/auth/authentication/devlocal_test.go @@ -8,6 +8,7 @@ import ( "net/url" "strings" + "github.com/alexedwards/scs/v2" "github.com/pkg/errors" "github.com/transcom/mymove/pkg/models" @@ -37,11 +38,13 @@ func (suite *AuthSuite) TestCreateUserHandlerMilMove() { req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") req.ParseForm() + var sessionManager *scs.SessionManager + sessionManager = scs.New() authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) - handler := NewCreateUserHandler(authContext, suite.DB(), appnames, FakeRSAKey, false, false) + handler := NewCreateUserHandler(authContext, suite.DB(), appnames, sessionManager) rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) + sessionManager.LoadAndSave(handler).ServeHTTP(rr, req) suite.Equal(http.StatusOK, rr.Code, "handler returned wrong status code") if status := rr.Code; status != http.StatusOK { @@ -78,11 +81,13 @@ func (suite *AuthSuite) TestCreateUserHandlerOffice() { req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") req.ParseForm() + var sessionManager *scs.SessionManager + sessionManager = scs.New() authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) - handler := NewCreateUserHandler(authContext, suite.DB(), appnames, FakeRSAKey, false, false) + handler := NewCreateUserHandler(authContext, suite.DB(), appnames, sessionManager) rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) + sessionManager.LoadAndSave(handler).ServeHTTP(rr, req) suite.Equal(http.StatusOK, rr.Code, "handler returned wrong status code") if status := rr.Code; status != http.StatusOK { @@ -123,11 +128,13 @@ func (suite *AuthSuite) TestCreateUserHandlerDPS() { req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") req.ParseForm() + var sessionManager *scs.SessionManager + sessionManager = scs.New() authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) - handler := NewCreateUserHandler(authContext, suite.DB(), appnames, FakeRSAKey, false, false) + handler := NewCreateUserHandler(authContext, suite.DB(), appnames, sessionManager) rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) + sessionManager.LoadAndSave(handler).ServeHTTP(rr, req) suite.Equal(http.StatusOK, rr.Code, "handler returned wrong status code") if status := rr.Code; status != http.StatusOK { @@ -164,11 +171,13 @@ func (suite *AuthSuite) TestCreateUserHandlerAdmin() { req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") req.ParseForm() + var sessionManager *scs.SessionManager + sessionManager = scs.New() authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) - handler := NewCreateUserHandler(authContext, suite.DB(), appnames, FakeRSAKey, false, false) + handler := NewCreateUserHandler(authContext, suite.DB(), appnames, sessionManager) rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) + sessionManager.LoadAndSave(handler).ServeHTTP(rr, req) suite.Equal(http.StatusOK, rr.Code, "handler returned wrong status code") if status := rr.Code; status != http.StatusOK { @@ -210,11 +219,13 @@ func (suite *AuthSuite) TestCreateAndLoginUserHandlerFromMilMoveToMilMove() { req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") req.ParseForm() + var sessionManager *scs.SessionManager + sessionManager = scs.New() authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) - handler := NewCreateAndLoginUserHandler(authContext, suite.DB(), appnames, FakeRSAKey, false, false) + handler := NewCreateAndLoginUserHandler(authContext, suite.DB(), appnames, sessionManager) rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) + sessionManager.LoadAndSave(handler).ServeHTTP(rr, req) suite.Equal(http.StatusSeeOther, rr.Code, "handler returned wrong status code") if status := rr.Code; status != http.StatusSeeOther { @@ -242,11 +253,13 @@ func (suite *AuthSuite) TestCreateAndLoginUserHandlerFromMilMoveToOffice() { req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") req.ParseForm() + var sessionManager *scs.SessionManager + sessionManager = scs.New() authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) - handler := NewCreateAndLoginUserHandler(authContext, suite.DB(), appnames, FakeRSAKey, false, false) + handler := NewCreateAndLoginUserHandler(authContext, suite.DB(), appnames, sessionManager) rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) + sessionManager.LoadAndSave(handler).ServeHTTP(rr, req) suite.Equal(http.StatusSeeOther, rr.Code, "handler returned wrong status code") if status := rr.Code; status != http.StatusSeeOther { @@ -274,11 +287,13 @@ func (suite *AuthSuite) TestCreateAndLoginUserHandlerFromMilMoveToAdmin() { req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") req.ParseForm() + var sessionManager *scs.SessionManager + sessionManager = scs.New() authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) - handler := NewCreateAndLoginUserHandler(authContext, suite.DB(), appnames, FakeRSAKey, false, false) + handler := NewCreateAndLoginUserHandler(authContext, suite.DB(), appnames, sessionManager) rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) + sessionManager.LoadAndSave(handler).ServeHTTP(rr, req) suite.Equal(http.StatusSeeOther, rr.Code, "handler returned wrong status code") if status := rr.Code; status != http.StatusSeeOther { @@ -306,11 +321,13 @@ func (suite *AuthSuite) TestCreateAndLoginUserHandlerFromOfficeToMilMove() { req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") req.ParseForm() + var sessionManager *scs.SessionManager + sessionManager = scs.New() authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) - handler := NewCreateAndLoginUserHandler(authContext, suite.DB(), appnames, FakeRSAKey, false, false) + handler := NewCreateAndLoginUserHandler(authContext, suite.DB(), appnames, sessionManager) rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) + sessionManager.LoadAndSave(handler).ServeHTTP(rr, req) suite.Equal(http.StatusSeeOther, rr.Code, "handler returned wrong status code") if status := rr.Code; status != http.StatusSeeOther { @@ -338,11 +355,13 @@ func (suite *AuthSuite) TestCreateAndLoginUserHandlerFromOfficeToAdmin() { req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") req.ParseForm() + var sessionManager *scs.SessionManager + sessionManager = scs.New() authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) - handler := NewCreateAndLoginUserHandler(authContext, suite.DB(), appnames, FakeRSAKey, false, false) + handler := NewCreateAndLoginUserHandler(authContext, suite.DB(), appnames, sessionManager) rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) + sessionManager.LoadAndSave(handler).ServeHTTP(rr, req) suite.Equal(http.StatusSeeOther, rr.Code, "handler returned wrong status code") if status := rr.Code; status != http.StatusSeeOther { @@ -370,11 +389,13 @@ func (suite *AuthSuite) TestCreateAndLoginUserHandlerFromAdminToMilMove() { req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") req.ParseForm() + var sessionManager *scs.SessionManager + sessionManager = scs.New() authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) - handler := NewCreateAndLoginUserHandler(authContext, suite.DB(), appnames, FakeRSAKey, false, false) + handler := NewCreateAndLoginUserHandler(authContext, suite.DB(), appnames, sessionManager) rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) + sessionManager.LoadAndSave(handler).ServeHTTP(rr, req) suite.Equal(http.StatusSeeOther, rr.Code, "handler returned wrong status code") if status := rr.Code; status != http.StatusSeeOther { @@ -402,11 +423,13 @@ func (suite *AuthSuite) TestCreateAndLoginUserHandlerFromAdminToOffice() { req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") req.ParseForm() + var sessionManager *scs.SessionManager + sessionManager = scs.New() authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) - handler := NewCreateAndLoginUserHandler(authContext, suite.DB(), appnames, FakeRSAKey, false, false) + handler := NewCreateAndLoginUserHandler(authContext, suite.DB(), appnames, sessionManager) rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) + sessionManager.LoadAndSave(handler).ServeHTTP(rr, req) suite.Equal(http.StatusSeeOther, rr.Code, "handler returned wrong status code") if status := rr.Code; status != http.StatusSeeOther { diff --git a/pkg/auth/cookie.go b/pkg/auth/cookie.go index e2838650af95..29fa36b19d32 100644 --- a/pkg/auth/cookie.go +++ b/pkg/auth/cookie.go @@ -6,8 +6,7 @@ import ( "strings" "time" - jwt "github.com/dgrijalva/jwt-go" - "github.com/gofrs/uuid" + "github.com/alexedwards/scs/v2" "github.com/gorilla/csrf" "github.com/pkg/errors" "go.uber.org/zap" @@ -48,12 +47,6 @@ const MaskedGorillaCSRFToken = "masked_gorilla_csrf" // SessionExpiryInMinutes is the number of minutes before a fallow session is harvested const SessionExpiryInMinutes = 15 -const sessionExpiryInSeconds = 15 * 60 - -// A representable date far in the future. The trouble with something like https://stackoverflow.com/a/32620397 -// is that it produces a date which may not marshall well into JSON which makes logging problematic -var likeForever = time.Date(9999, 1, 1, 12, 0, 0, 0, time.UTC) -var likeForeverInSeconds = 99999999 // GetExpiryTimeFromMinutes returns 'min' minutes from now func GetExpiryTimeFromMinutes(min int64) time.Time { @@ -81,59 +74,9 @@ func DeleteCookie(w http.ResponseWriter, name string) { http.SetCookie(w, &c) } -// SessionClaims wraps StandardClaims with some Session info -type SessionClaims struct { - jwt.StandardClaims - SessionValue Session -} - -func signTokenStringWithUserInfo(expiry time.Time, session *Session, secret string) (string, error) { - claims := SessionClaims{ - StandardClaims: jwt.StandardClaims{ExpiresAt: expiry.Unix()}, - SessionValue: *session, - } - - token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) - rsaKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(secret)) - if err != nil { - err = errors.Wrap(err, "Parsing RSA key from PEM") - return "", err - } - - ss, err := token.SignedString(rsaKey) - if err != nil { - err = errors.Wrap(err, "Signing string with token") - return "", err - } - return ss, err -} - -func sessionClaimsFromRequest(logger Logger, secret string, appName Application, r *http.Request) (claims *SessionClaims, ok bool) { - // Name the cookie with the app name - cookieName := fmt.Sprintf("%s_%s", string(appName), UserSessionCookieName) - cookie, err := r.Cookie(cookieName) - if err != nil { - // No cookie set on client - return - } - - token, err := jwt.ParseWithClaims(cookie.Value, &SessionClaims{}, func(token *jwt.Token) (interface{}, error) { - rsaKey, parseRSAPrivateKeyFromPEMErr := jwt.ParseRSAPrivateKeyFromPEM([]byte(secret)) - return &rsaKey.PublicKey, parseRSAPrivateKeyFromPEMErr - }) - - if err != nil || token == nil || !token.Valid { - logger.Error("Failed token validation", zap.Error(err)) - return - } - - // The token actually just stores a Claims interface, so we need to explicitly cast back to UserClaims - claims, ok = token.Claims.(*SessionClaims) - if !ok { - logger.Error("Failed getting claims from token") - return - } - return claims, ok +// SessionCookieName returns the session cookie name +func SessionCookieName(session *Session) string { + return fmt.Sprintf("%s_%s", string(session.ApplicationName), UserSessionCookieName) } // WriteMaskedCSRFCookie update the masked_gorilla_csrf cookie value @@ -171,48 +114,6 @@ func MaskedCSRFMiddleware(logger Logger, useSecureCookie bool) func(next http.Ha } } -// WriteSessionCookie update the cookie for the session -func WriteSessionCookie(w http.ResponseWriter, session *Session, secret string, noSessionTimeout bool, logger Logger, useSecureCookie bool) { - // Delete the cookie - cookieName := fmt.Sprintf("%s_%s", string(session.ApplicationName), UserSessionCookieName) - cookie := http.Cookie{ - Name: cookieName, - Value: "blank", - Path: "/", - Expires: time.Unix(0, 0), - MaxAge: -1, - HttpOnly: true, - SameSite: http.SameSiteLaxMode, // Using 'lax' mode now since 'strict' breaks the use of the login.gov redirect - Secure: useSecureCookie, - } - - // unless we have a valid session - if session.IDToken != "" && session.UserID != uuid.Nil { - expiry := GetExpiryTimeFromMinutes(SessionExpiryInMinutes) - maxAge := sessionExpiryInSeconds - // Never expire token if in development - if noSessionTimeout { - expiry = likeForever - maxAge = likeForeverInSeconds - } - - ss, err := signTokenStringWithUserInfo(expiry, session, secret) - if err != nil { - logger.Error("Generating signed token string", zap.Error(err)) - } else { - logger.Info("Cookie", zap.Int("Size", len(ss))) - cookie.Value = ss - cookie.Expires = expiry - cookie.MaxAge = maxAge - } - } - // http.SetCookie calls Header().Add() instead of .Set(), which can result in duplicate cookies - // It's ok to use this here because we want to delete and rewrite `Set-Cookie` on login or if the - // session token is lost. However, we would normally use http.SetCookie for any other cookie operations - // so as not to delete the session token. - w.Header().Set("Set-Cookie", cookie.String()) -} - // ApplicationName returns the application name given the hostname func ApplicationName(hostname string, appnames ApplicationServername) (Application, error) { var appName Application @@ -233,7 +134,7 @@ func ApplicationName(hostname string, appnames ApplicationServername) (Applicati } // SessionCookieMiddleware handle serializing and de-serializing the session between the user_session cookie and the request context -func SessionCookieMiddleware(serverLogger Logger, secret string, noSessionTimeout bool, appnames ApplicationServername, useSecureCookie bool) func(next http.Handler) http.Handler { +func SessionCookieMiddleware(serverLogger Logger, appnames ApplicationServername, sessionManager *scs.SessionManager) func(next http.Handler) http.Handler { serverLogger.Info("Creating session", zap.String("milServername", appnames.MilServername), zap.String("officeServername", appnames.OfficeServername), @@ -261,9 +162,10 @@ func SessionCookieMiddleware(serverLogger Logger, secret string, noSessionTimeou http.Error(w, http.StatusText(400), http.StatusBadRequest) return } - claims, ok := sessionClaimsFromRequest(logger, secret, appName, r) - if ok { - session = claims.SessionValue + + existingSession := sessionManager.Get(r.Context(), "session") + if existingSession != nil { + session = existingSession.(Session) } // Set more information on the session @@ -271,7 +173,8 @@ func SessionCookieMiddleware(serverLogger Logger, secret string, noSessionTimeou session.Hostname = strings.ToLower(hostname) // And update the cookie. May get over-ridden later - WriteSessionCookie(w, &session, secret, noSessionTimeout, logger, useSecureCookie) + sessionManager.Cookie.Name = SessionCookieName(&session) + sessionManager.Put(r.Context(), "session", session) // And put the session info into the request context next.ServeHTTP(w, r.WithContext(SetSessionInContext(ctx, &session))) diff --git a/pkg/auth/cookie_test.go b/pkg/auth/cookie_test.go index 43cbeec086b8..0fa5e0f79794 100644 --- a/pkg/auth/cookie_test.go +++ b/pkg/auth/cookie_test.go @@ -1,35 +1,18 @@ package auth import ( - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "encoding/pem" + "encoding/gob" "fmt" "net/http" "net/http/httptest" "strings" "time" - "github.com/gofrs/uuid" - "github.com/pkg/errors" + "github.com/alexedwards/scs/v2" ) -func createRandomRSAPEM() (s string, err error) { - priv, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - err = errors.Wrap(err, "failed to generate key") - return - } - - asn1 := x509.MarshalPKCS1PrivateKey(priv) - privBytes := pem.EncodeToMemory(&pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: asn1, - }) - s = string(privBytes[:]) - - return +func (suite *authSuite) SetupTest() { + gob.Register(Session{}) } func getHandlerParamsWithToken(ss string, expiry time.Time) (*httptest.ResponseRecorder, *http.Request) { @@ -52,24 +35,21 @@ func getHandlerParamsWithToken(ss string, expiry time.Time) (*httptest.ResponseR } func (suite *authSuite) TestSessionCookieMiddlewareWithBadToken() { - t := suite.T() fakeToken := "some_token" - pem, err := createRandomRSAPEM() - if err != nil { - t.Error("error creating RSA key", err) - } + var sessionManager *scs.SessionManager + sessionManager = scs.New() var resultingSession *Session handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { resultingSession = SessionFromRequestContext(r) }) appnames := ApplicationTestServername() - middleware := SessionCookieMiddleware(suite.logger, pem, false, appnames, false)(handler) + middleware := SessionCookieMiddleware(suite.logger, appnames, sessionManager)(handler) expiry := GetExpiryTimeFromMinutes(SessionExpiryInMinutes) rr, req := getHandlerParamsWithToken(fakeToken, expiry) - middleware.ServeHTTP(rr, req) + sessionManager.LoadAndSave(middleware).ServeHTTP(rr, req) // We should be not be redirected since we're not enforcing auth suite.Equal(http.StatusOK, rr.Code, "handler returned wrong status code") @@ -79,142 +59,6 @@ func (suite *authSuite) TestSessionCookieMiddlewareWithBadToken() { suite.Equal("", resultingSession.IDToken, "Expected empty IDToken from bad cookie") } -func (suite *authSuite) TestSessionCookieMiddlewareWithValidToken() { - t := suite.T() - email := "some_email@domain.com" - idToken := "fake_id_token" - fakeUUID, _ := uuid.FromString("39b28c92-0506-4bef-8b57-e39519f42dc2") - - pem, err := createRandomRSAPEM() - if err != nil { - t.Fatal(err) - } - - expiry := GetExpiryTimeFromMinutes(SessionExpiryInMinutes) - incomingSession := Session{ - UserID: fakeUUID, - Email: email, - IDToken: idToken, - } - ss, err := signTokenStringWithUserInfo(expiry, &incomingSession, pem) - if err != nil { - t.Fatal(err) - } - rr, req := getHandlerParamsWithToken(ss, expiry) - - var resultingSession *Session - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - resultingSession = SessionFromRequestContext(r) - }) - appnames := ApplicationTestServername() - middleware := SessionCookieMiddleware(suite.logger, pem, false, appnames, false)(handler) - - middleware.ServeHTTP(rr, req) - - // We should get a 200 OK - suite.Equal(http.StatusOK, rr.Code, "handler returned wrong status code") - - // And there should be an ID token in the request context - suite.NotNil(resultingSession) - suite.Equal(idToken, resultingSession.IDToken, "handler returned wrong id_token") - - // And the cookie should be renewed - setCookies := rr.HeaderMap["Set-Cookie"] - suite.Equal(1, len(setCookies), "expected cookie to be set") -} - -func (suite *authSuite) TestSessionCookieMiddlewareWithExpiredToken() { - t := suite.T() - email := "some_email@domain.com" - idToken := "fake_id_token" - fakeUUID, _ := uuid.FromString("39b28c92-0506-4bef-8b57-e39519f42dc2") - - pem, err := createRandomRSAPEM() - if err != nil { - t.Fatal(err) - } - - expiry := GetExpiryTimeFromMinutes(-1) - incomingSession := Session{ - UserID: fakeUUID, - Email: email, - IDToken: idToken, - } - ss, err := signTokenStringWithUserInfo(expiry, &incomingSession, pem) - if err != nil { - t.Fatal(err) - } - - var resultingSession *Session - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - resultingSession = SessionFromRequestContext(r) - }) - appnames := ApplicationTestServername() - middleware := SessionCookieMiddleware(suite.logger, pem, false, appnames, false)(handler) - - rr, req := getHandlerParamsWithToken(ss, expiry) - - middleware.ServeHTTP(rr, req) - - // We should be not be redirected since we're not enforcing auth - suite.Equal(http.StatusOK, rr.Code, "handler returned wrong status code") - - // And there should be no token passed through - // And there should be no token passed through - suite.NotNil(resultingSession) - suite.Equal("", resultingSession.IDToken, "Expected empty IDToken from expired") - suite.Equal(uuid.Nil, resultingSession.UserID, "Expected no UUID from expired cookie") - - // And the cookie should be set - setCookies := rr.HeaderMap["Set-Cookie"] - suite.Equal(1, len(setCookies), "expected cookie to be set") -} - -func (suite *authSuite) TestSessionCookiePR161162731() { - t := suite.T() - email := "some_email@domain.com" - idToken := "fake_id_token" - fakeUUID, _ := uuid.FromString("39b28c92-0506-4bef-8b57-e39519f42dc2") - - pem, err := createRandomRSAPEM() - if err != nil { - t.Fatal(err) - } - - expiry := GetExpiryTimeFromMinutes(SessionExpiryInMinutes) - incomingSession := Session{ - UserID: fakeUUID, - Email: email, - IDToken: idToken, - } - ss, err := signTokenStringWithUserInfo(expiry, &incomingSession, pem) - if err != nil { - t.Fatal(err) - } - rr, req := getHandlerParamsWithToken(ss, expiry) - - var resultingSession *Session - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - resultingSession = SessionFromRequestContext(r) - WriteSessionCookie(w, resultingSession, "freddy", false, suite.logger, false) - }) - appnames := ApplicationTestServername() - middleware := SessionCookieMiddleware(suite.logger, pem, false, appnames, false)(handler) - - middleware.ServeHTTP(rr, req) - - // We should get a 200 OK - suite.Equal(http.StatusOK, rr.Code, "handler returned wrong status code") - - // And there should be an ID token in the request context - suite.NotNil(resultingSession) - suite.Equal(idToken, resultingSession.IDToken, "handler returned wrong id_token") - - // And the cookie should be renewed - setCookies := rr.HeaderMap["Set-Cookie"] - suite.Equal(1, len(setCookies), "expected cookie to be set") -} - func (suite *authSuite) TestMaskedCSRFMiddleware() { rr := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/", nil) @@ -261,7 +105,9 @@ func (suite *authSuite) TestMaskedCSRFMiddlewareCreatesNewToken() { func (suite *authSuite) TestMiddlewareConstructor() { appnames := ApplicationTestServername() - adm := SessionCookieMiddleware(suite.logger, "secret", false, appnames, false) + var sessionManager *scs.SessionManager + sessionManager = scs.New() + adm := SessionCookieMiddleware(suite.logger, appnames, sessionManager) suite.NotNil(adm) } @@ -276,16 +122,18 @@ func (suite *authSuite) TestMiddlewareMilApp() { suite.False(session.IsAdminApp(), "first should not be admin app") suite.Equal(appnames.MilServername, session.Hostname) }) - milMoveMiddleware := SessionCookieMiddleware(suite.logger, "secret", false, appnames, false)(milMoveTestHandler) + var sessionManager *scs.SessionManager + sessionManager = scs.New() + milMoveMiddleware := SessionCookieMiddleware(suite.logger, appnames, sessionManager)(milMoveTestHandler) req := httptest.NewRequest("GET", fmt.Sprintf("http://%s/some_url", appnames.MilServername), nil) - milMoveMiddleware.ServeHTTP(rr, req) + sessionManager.LoadAndSave(milMoveMiddleware).ServeHTTP(rr, req) req, _ = http.NewRequest("GET", fmt.Sprintf("http://%s:8080/some_url", appnames.MilServername), nil) - milMoveMiddleware.ServeHTTP(rr, req) + sessionManager.LoadAndSave(milMoveMiddleware).ServeHTTP(rr, req) req, _ = http.NewRequest("GET", fmt.Sprintf("http://%s:8080/some_url", strings.ToUpper(appnames.MilServername)), nil) - milMoveMiddleware.ServeHTTP(rr, req) + sessionManager.LoadAndSave(milMoveMiddleware).ServeHTTP(rr, req) } func (suite *authSuite) TestMiddlwareOfficeApp() { @@ -299,16 +147,18 @@ func (suite *authSuite) TestMiddlwareOfficeApp() { suite.False(session.IsAdminApp(), "should not be admin app") suite.Equal(appnames.OfficeServername, session.Hostname) }) - officeMiddleware := SessionCookieMiddleware(suite.logger, "secret", false, appnames, false)(officeTestHandler) + var sessionManager *scs.SessionManager + sessionManager = scs.New() + officeMiddleware := SessionCookieMiddleware(suite.logger, appnames, sessionManager)(officeTestHandler) req := httptest.NewRequest("GET", fmt.Sprintf("http://%s/some_url", appnames.OfficeServername), nil) - officeMiddleware.ServeHTTP(rr, req) + sessionManager.LoadAndSave(officeMiddleware).ServeHTTP(rr, req) req, _ = http.NewRequest("GET", fmt.Sprintf("http://%s:8080/some_url", appnames.OfficeServername), nil) - officeMiddleware.ServeHTTP(rr, req) + sessionManager.LoadAndSave(officeMiddleware).ServeHTTP(rr, req) req, _ = http.NewRequest("GET", fmt.Sprintf("http://%s:8080/some_url", strings.ToUpper(appnames.OfficeServername)), nil) - officeMiddleware.ServeHTTP(rr, req) + sessionManager.LoadAndSave(officeMiddleware).ServeHTTP(rr, req) } func (suite *authSuite) TestMiddlwareAdminApp() { @@ -322,16 +172,18 @@ func (suite *authSuite) TestMiddlwareAdminApp() { suite.True(session.IsAdminApp(), "should be admin app") suite.Equal(AdminTestHost, session.Hostname) }) - adminMiddleware := SessionCookieMiddleware(suite.logger, "secret", false, appnames, false)(adminTestHandler) + var sessionManager *scs.SessionManager + sessionManager = scs.New() + adminMiddleware := SessionCookieMiddleware(suite.logger, appnames, sessionManager)(adminTestHandler) req := httptest.NewRequest("GET", fmt.Sprintf("http://%s/some_url", AdminTestHost), nil) - adminMiddleware.ServeHTTP(rr, req) + sessionManager.LoadAndSave(adminMiddleware).ServeHTTP(rr, req) req, _ = http.NewRequest("GET", fmt.Sprintf("http://%s:8080/some_url", AdminTestHost), nil) - adminMiddleware.ServeHTTP(rr, req) + sessionManager.LoadAndSave(adminMiddleware).ServeHTTP(rr, req) req, _ = http.NewRequest("GET", fmt.Sprintf("http://%s:8080/some_url", strings.ToUpper(AdminTestHost)), nil) - adminMiddleware.ServeHTTP(rr, req) + sessionManager.LoadAndSave(adminMiddleware).ServeHTTP(rr, req) } func (suite *authSuite) TestMiddlewareBadApp() { @@ -341,9 +193,11 @@ func (suite *authSuite) TestMiddlewareBadApp() { suite.Fail("Should not be called") }) appnames := ApplicationTestServername() - noAppMiddleware := SessionCookieMiddleware(suite.logger, "secret", false, appnames, false)(noAppTestHandler) + var sessionManager *scs.SessionManager + sessionManager = scs.New() + noAppMiddleware := SessionCookieMiddleware(suite.logger, appnames, sessionManager)(noAppTestHandler) req := httptest.NewRequest("GET", "http://totally.bogus.hostname/some_url", nil) - noAppMiddleware.ServeHTTP(rr, req) + sessionManager.LoadAndSave(noAppMiddleware).ServeHTTP(rr, req) suite.Equal(http.StatusBadRequest, rr.Code, "Should get an error ") } diff --git a/pkg/handlers/internalapi/api.go b/pkg/handlers/internalapi/api.go index 5db0018ef18a..f37319b0e1a8 100644 --- a/pkg/handlers/internalapi/api.go +++ b/pkg/handlers/internalapi/api.go @@ -8,6 +8,7 @@ import ( movedocument "github.com/transcom/mymove/pkg/services/move_documents" postalcodeservice "github.com/transcom/mymove/pkg/services/postal_codes" + "github.com/alexedwards/scs/v2" "github.com/go-openapi/loads" "github.com/go-openapi/runtime" "github.com/pkg/errors" @@ -18,7 +19,7 @@ import ( ) // NewInternalAPI returns the internal API -func NewInternalAPI(context handlers.HandlerContext) *internalops.MymoveAPI { +func NewInternalAPI(context handlers.HandlerContext, sessionManager *scs.SessionManager) *internalops.MymoveAPI { internalSpec, err := loads.Analyzed(internalapi.SwaggerJSON, "") if err != nil { @@ -69,7 +70,7 @@ func NewInternalAPI(context handlers.HandlerContext) *internalops.MymoveAPI { internalAPI.MoveDocsCreateWeightTicketDocumentHandler = CreateWeightTicketSetDocumentHandler{context} - internalAPI.ServiceMembersCreateServiceMemberHandler = CreateServiceMemberHandler{context} + internalAPI.ServiceMembersCreateServiceMemberHandler = CreateServiceMemberHandler{context, sessionManager} internalAPI.ServiceMembersPatchServiceMemberHandler = PatchServiceMemberHandler{context} internalAPI.ServiceMembersShowServiceMemberHandler = ShowServiceMemberHandler{context} internalAPI.ServiceMembersShowServiceMemberOrdersHandler = ShowServiceMemberOrdersHandler{context} diff --git a/pkg/handlers/internalapi/service_members.go b/pkg/handlers/internalapi/service_members.go index 5add7efbac96..2e43b60b9ffa 100644 --- a/pkg/handlers/internalapi/service_members.go +++ b/pkg/handlers/internalapi/service_members.go @@ -3,6 +3,7 @@ package internalapi import ( "context" + "github.com/alexedwards/scs/v2" "github.com/go-openapi/runtime/middleware" "github.com/gobuffalo/validate" "github.com/gofrs/uuid" @@ -69,6 +70,7 @@ func payloadForServiceMemberModel(storer storage.FileStorer, serviceMember model // CreateServiceMemberHandler creates a new service member via POST /serviceMember type CreateServiceMemberHandler struct { handlers.HandlerContext + sessionManager *scs.SessionManager } // Handle ... creates a new ServiceMember from a request payload @@ -148,10 +150,12 @@ func (h CreateServiceMemberHandler) Handle(params servicememberop.CreateServiceM if newServiceMember.LastName != nil { session.LastName = *(newServiceMember.LastName) } + // Update session cookie here instead of in responders? + // h.sessionManager.Put(ctx, "session", session) // And return serviceMemberPayload := payloadForServiceMemberModel(h.FileStorer(), newServiceMember, h.HandlerContext.GetFeatureFlag(cli.FeatureFlagAccessCode)) responder := servicememberop.NewCreateServiceMemberCreated().WithPayload(serviceMemberPayload) - return handlers.NewCookieUpdateResponder(params.HTTPRequest, h.CookieSecret(), h.NoSessionTimeout(), logger, responder, h.UseSecureCookie()) + return handlers.NewCookieUpdateResponder(params.HTTPRequest, logger, responder, h.sessionManager, session) } // ShowServiceMemberHandler returns a serviceMember for a user and service member ID diff --git a/pkg/handlers/internalapi/service_members_test.go b/pkg/handlers/internalapi/service_members_test.go index 481c3ec6fc51..d3bcb9f55b12 100644 --- a/pkg/handlers/internalapi/service_members_test.go +++ b/pkg/handlers/internalapi/service_members_test.go @@ -6,6 +6,7 @@ import ( "net/http" "net/http/httptest" + "github.com/alexedwards/scs/v2" "github.com/go-openapi/strfmt" "github.com/go-openapi/swag" "github.com/gofrs/uuid" @@ -85,8 +86,9 @@ func (suite *HandlerSuite) TestSubmitServiceMemberHandlerNoValues() { CreateServiceMemberPayload: &newServiceMemberPayload, HTTPRequest: req, } - - handler := CreateServiceMemberHandler{handlers.NewHandlerContext(suite.DB(), suite.TestLogger())} + var sessionManager *scs.SessionManager + sessionManager = scs.New() + handler := CreateServiceMemberHandler{handlers.NewHandlerContext(suite.DB(), suite.TestLogger()), sessionManager} response := handler.Handle(params) suite.Assertions.IsType(&handlers.CookieUpdateResponder{}, response) @@ -151,7 +153,9 @@ func (suite *HandlerSuite) TestSubmitServiceMemberHandlerAllValues() { HTTPRequest: req, } - handler := CreateServiceMemberHandler{handlers.NewHandlerContext(suite.DB(), suite.TestLogger())} + var sessionManager *scs.SessionManager + sessionManager = scs.New() + handler := CreateServiceMemberHandler{handlers.NewHandlerContext(suite.DB(), suite.TestLogger()), sessionManager} response := handler.Handle(params) suite.Assertions.IsType(&handlers.CookieUpdateResponder{}, response) @@ -190,7 +194,9 @@ func (suite *HandlerSuite) TestSubmitServiceMemberSSN() { HTTPRequest: req, } - handler := CreateServiceMemberHandler{handlers.NewHandlerContext(suite.DB(), suite.TestLogger())} + var sessionManager *scs.SessionManager + sessionManager = scs.New() + handler := CreateServiceMemberHandler{handlers.NewHandlerContext(suite.DB(), suite.TestLogger()), sessionManager} response := handler.Handle(params) suite.Assertions.IsType(&handlers.CookieUpdateResponder{}, response) diff --git a/pkg/handlers/responders.go b/pkg/handlers/responders.go index 4ba7220b5387..2bb870b5ef31 100644 --- a/pkg/handlers/responders.go +++ b/pkg/handlers/responders.go @@ -1,8 +1,10 @@ package handlers import ( + "context" "net/http" + "github.com/alexedwards/scs/v2" "github.com/go-openapi/runtime" "github.com/go-openapi/runtime/middleware" @@ -12,28 +14,26 @@ import ( // CookieUpdateResponder wraps a swagger middleware.Responder in code which sets the session_cookie // See: https://github.com/go-swagger/go-swagger/issues/748 type CookieUpdateResponder struct { - session *auth.Session - cookieSecret string - noSessionTimeout bool - logger Logger - Responder middleware.Responder - useSecureCookie bool + session *auth.Session + logger Logger + Responder middleware.Responder + sessionManager *scs.SessionManager + ctx context.Context } // NewCookieUpdateResponder constructs a wrapper for the responder which will update cookies -func NewCookieUpdateResponder(request *http.Request, secret string, noSessionTimeout bool, logger Logger, responder middleware.Responder, useSecureCookie bool) middleware.Responder { +func NewCookieUpdateResponder(request *http.Request, logger Logger, responder middleware.Responder, sessionManager *scs.SessionManager, session *auth.Session) middleware.Responder { return &CookieUpdateResponder{ - session: auth.SessionFromRequestContext(request), - cookieSecret: secret, - noSessionTimeout: noSessionTimeout, - logger: logger, - Responder: responder, - useSecureCookie: useSecureCookie, + session: session, + logger: logger, + Responder: responder, + sessionManager: sessionManager, + ctx: request.Context(), } } // WriteResponse updates the session cookie before writing out the details of the response func (cur *CookieUpdateResponder) WriteResponse(rw http.ResponseWriter, p runtime.Producer) { - auth.WriteSessionCookie(rw, cur.session, cur.cookieSecret, cur.noSessionTimeout, cur.logger, cur.useSecureCookie) + cur.sessionManager.Put(cur.ctx, "session", cur.session) cur.Responder.WriteResponse(rw, p) } diff --git a/pkg/middleware/request_logger.go b/pkg/middleware/request_logger.go index fd7a85649954..8c0fb39950a0 100644 --- a/pkg/middleware/request_logger.go +++ b/pkg/middleware/request_logger.go @@ -63,6 +63,12 @@ func RequestLogger(serverLogger Logger) func(inner http.Handler) http.Handler { if session := auth.SessionFromContext(ctx); session != nil { if session.UserID != uuid.Nil { fields = append(fields, zap.String("user-id", session.UserID.String())) + var sessionID string + cookie, err := r.Cookie(auth.SessionCookieName(session)) + if err == nil { + sessionID = cookie.Value + } + fields = append(fields, zap.String("session-id", sessionID)) } if session.IsServiceMember() { fields = append(fields, zap.String("service-member-id", session.ServiceMemberID.String()))