diff --git a/cmd/milmove/serve.go b/cmd/milmove/serve.go index d6443f86f1d9..5bf4f2118430 100644 --- a/cmd/milmove/serve.go +++ b/cmd/milmove/serve.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "crypto/tls" + "encoding/gob" "encoding/hex" "encoding/json" "fmt" @@ -18,13 +19,17 @@ import ( "strings" "sync" "syscall" + "time" + "github.com/alexedwards/scs/redisstore" + "github.com/alexedwards/scs/v2" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials/stscreds" awssession "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/sts" "github.com/dgrijalva/jwt-go" "github.com/gobuffalo/pop" + "github.com/gomodule/redigo/redis" "github.com/gorilla/csrf" "github.com/pkg/errors" "github.com/spf13/cobra" @@ -477,10 +482,45 @@ func serveFunction(cmd *cobra.Command, args []string) error { logger.Fatal("Registering login provider", zap.Error(err)) } + gob.Register(auth.Session{}) + var sessionManager *scs.SessionManager + sessionManager = scs.New() + pool := &redis.Pool{ + MaxIdle: 10, + Dial: func() (redis.Conn, error) { + return redis.Dial("tcp", "localhost:6379") + }, + } + sessionManager.Store = redisstore.New(pool) + + // IdleTimeout controls the maximum length of time a session can be inactive + // before it expires. For example, some applications may wish to set this so + // there is a timeout after 20 minutes of inactivity. By default IdleTimeout + // is not set and there is no inactivity timeout. + // This preserves the behavior we had when we managed our own session + // cookies, where we extended the session expiry by 15 minutes on every + // request while the session was still valid. + sessionManager.IdleTimeout = 15 * time.Minute + + // Lifetime controls the maximum length of time that a session is valid for + // before it expires. The lifetime is an 'absolute expiry' which is set when + // the session is first created and does not change. The default value is 24 + // hours. + // sessionManager.Lifetime = 24 * time.Hour + + sessionManager.Cookie.Path = "/" + + // A value of false means the session cookie will be deleted when the + // browser is closed. + sessionManager.Cookie.Persist = false + useSecureCookie := !isDevOrTest + if useSecureCookie { + sessionManager.Cookie.Secure = true + } // Session management and authentication middleware noSessionTimeout := v.GetBool(cli.NoSessionTimeoutFlag) - sessionCookieMiddleware := auth.SessionCookieMiddleware(logger, clientAuthSecretKey, noSessionTimeout, appnames, useSecureCookie) + sessionCookieMiddleware := auth.SessionCookieMiddleware(logger, appnames, sessionManager) maskedCSRFMiddleware := auth.MaskedCSRFMiddleware(logger, useSecureCookie) userAuthMiddleware := authentication.UserAuthMiddleware(logger) if v.GetBool(cli.FeatureFlagRoleBasedAuth) { @@ -495,6 +535,7 @@ func serveFunction(cmd *cobra.Command, args []string) error { handlerContext.SetUseSecureCookie(useSecureCookie) if noSessionTimeout { handlerContext.SetNoSessionTimeout() + sessionManager.IdleTimeout = 24 * time.Hour } handlerContext.SetAppNames(appnames) @@ -809,7 +850,7 @@ func serveFunction(cmd *cobra.Command, args []string) error { root.Use(csrf.Protect(csrfAuthKey, csrf.Secure(!isDevOrTest), csrf.Path("/"), csrf.CookieName(auth.GorillaCSRFToken))) root.Use(maskedCSRFMiddleware) - site.Handle(pat.New("/*"), root) + site.Handle(pat.New("/*"), sessionManager.LoadAndSave(root)) if v.GetBool(cli.ServeAPIInternalFlag) { internalMux := goji.SubMux() @@ -827,7 +868,7 @@ func serveFunction(cmd *cobra.Command, args []string) error { internalMux.Handle(pat.New("/*"), internalAPIMux) internalAPIMux.Use(userAuthMiddleware) internalAPIMux.Use(middleware.NoCache(logger)) - api := internalapi.NewInternalAPI(handlerContext) + api := internalapi.NewInternalAPI(handlerContext, sessionManager) internalAPIMux.Handle(pat.New("/*"), api.Serve(nil)) if handlerContext.GetFeatureFlag(cli.FeatureFlagRoleBasedAuth) { internalAPIMux.Use(roleAuthMiddleware(api.Context())) @@ -886,18 +927,18 @@ func serveFunction(cmd *cobra.Command, args []string) error { ) authMux := goji.SubMux() root.Handle(pat.New("/auth/*"), authMux) - authMux.Handle(pat.Get("/login-gov"), authentication.RedirectHandler{Context: authContext, UseSecureCookie: useSecureCookie}) - authMux.Handle(pat.Get("/login-gov/callback"), authentication.NewCallbackHandler(authContext, dbConnection, clientAuthSecretKey, noSessionTimeout, useSecureCookie)) - authMux.Handle(pat.Post("/logout"), authentication.NewLogoutHandler(authContext, clientAuthSecretKey, noSessionTimeout, useSecureCookie)) + authMux.Handle(pat.Get("/login-gov"), authentication.NewRedirectHandler(authContext, sessionManager)) + authMux.Handle(pat.Get("/login-gov/callback"), authentication.NewCallbackHandler(authContext, dbConnection, sessionManager)) + authMux.Handle(pat.Post("/logout"), authentication.NewLogoutHandler(authContext, sessionManager)) if v.GetBool(cli.DevlocalAuthFlag) { logger.Info("Enabling devlocal auth") localAuthMux := goji.SubMux() root.Handle(pat.New("/devlocal-auth/*"), localAuthMux) - localAuthMux.Handle(pat.Get("/login"), authentication.NewUserListHandler(authContext, dbConnection, clientAuthSecretKey, noSessionTimeout, useSecureCookie)) - localAuthMux.Handle(pat.Post("/login"), authentication.NewAssignUserHandler(authContext, dbConnection, appnames, clientAuthSecretKey, noSessionTimeout, useSecureCookie)) - localAuthMux.Handle(pat.Post("/new"), authentication.NewCreateAndLoginUserHandler(authContext, dbConnection, appnames, clientAuthSecretKey, noSessionTimeout, useSecureCookie)) - localAuthMux.Handle(pat.Post("/create"), authentication.NewCreateUserHandler(authContext, dbConnection, appnames, clientAuthSecretKey, noSessionTimeout, useSecureCookie)) + localAuthMux.Handle(pat.Get("/login"), authentication.NewUserListHandler(authContext, dbConnection, sessionManager)) + localAuthMux.Handle(pat.Post("/login"), authentication.NewAssignUserHandler(authContext, dbConnection, appnames, sessionManager)) + localAuthMux.Handle(pat.Post("/new"), authentication.NewCreateAndLoginUserHandler(authContext, dbConnection, appnames, sessionManager)) + localAuthMux.Handle(pat.Post("/create"), authentication.NewCreateUserHandler(authContext, dbConnection, appnames, sessionManager)) if stringSliceContains([]string{cli.EnvironmentTest, cli.EnvironmentDevelopment}, v.GetString(cli.EnvironmentFlag)) { logger.Info("Adding devlocal CA to root CAs") diff --git a/go.mod b/go.mod index 46291eae047f..209b6b5f1c61 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,8 @@ require ( github.com/0xAX/notificator v0.0.0-20191016112426-3962a5ea8da1 // indirect github.com/99designs/aws-vault v4.5.1+incompatible github.com/99designs/keyring v1.1.4 + github.com/alexedwards/scs/redisstore v0.0.0-20200225172727-3308e1066830 + github.com/alexedwards/scs/v2 v2.3.0 github.com/aws/aws-sdk-go v1.30.7 github.com/codegangsta/envy v0.0.0-20141216192214-4b78388c8ce4 // indirect github.com/codegangsta/gin v0.0.0-20171026143024-cafe2ce98974 @@ -34,6 +36,7 @@ require ( github.com/gocarina/gocsv v0.0.0-20190927101021-3ecffd272576 github.com/gofrs/flock v0.7.1 github.com/gofrs/uuid v3.2.0+incompatible + github.com/gomodule/redigo v2.0.0+incompatible github.com/gorilla/csrf v1.6.2 github.com/imdario/mergo v0.3.9 github.com/jessevdk/go-flags v1.4.0 @@ -62,9 +65,10 @@ require ( go.mozilla.org/pkcs7 v0.0.0-20181213175627-3cffc6fbfe83 go.uber.org/zap v1.14.1 goji.io v2.0.2+incompatible - golang.org/x/crypto v0.0.0-20200128174031-69ecbb4d6d5d + golang.org/x/crypto v0.0.0-20200317142112-1b76d66859c6 golang.org/x/net v0.0.0-20200226121028-0de0cce0169b golang.org/x/text v0.3.2 + google.golang.org/appengine v1.6.5 // indirect gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect gopkg.in/go-playground/assert.v1 v1.2.1 // indirect gopkg.in/go-playground/validator.v9 v9.31.0 diff --git a/go.sum b/go.sum index 667e2d3cb545..71615f71f030 100644 --- a/go.sum +++ b/go.sum @@ -21,6 +21,10 @@ github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdko github.com/agnivade/levenshtein v1.0.1/go.mod h1:CURSv5d9Uaml+FovSIICkLbAUZ9S4RqaHDIsdSBg7lM= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/alexedwards/scs/redisstore v0.0.0-20200225172727-3308e1066830 h1:84mrg6CV1OZ8RnRZ6zkJ4bvjg/6CHXPPVDKQKecVb+0= +github.com/alexedwards/scs/redisstore v0.0.0-20200225172727-3308e1066830/go.mod h1:u2uSOc9yz8e3S+beMudSPwYL36kcbBChOLBJDAQNy38= +github.com/alexedwards/scs/v2 v2.3.0 h1:V8rtn2P5QGh8C9S7T/ikBo/AdA27vDoQJPbiAaOCmFg= +github.com/alexedwards/scs/v2 v2.3.0/go.mod h1:ToaROZxyKukJKT/xLcVQAChi5k6+Pn1Gvmdl7h3RRj8= github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883/go.mod h1:rCTlJbsFo29Kk6CurOXKm700vrz8f0KW0JNfpkRJY/8= github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= github.com/asaskevich/govalidator v0.0.0-20180720115003-f9ffefc3facf/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY= @@ -264,6 +268,8 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/gomodule/redigo v2.0.0+incompatible h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0= +github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY= @@ -596,8 +602,8 @@ golang.org/x/crypto v0.0.0-20190617133340-57b3e21c3d56/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20190621222207-cc06ce4a13d4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20200128174031-69ecbb4d6d5d h1:9FCpayM9Egr1baVnV1SX0H87m+XB0B8S0hAMi99X/3U= -golang.org/x/crypto v0.0.0-20200128174031-69ecbb4d6d5d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20200317142112-1b76d66859c6 h1:TjszyFsQsyZNHwdVdZ5m7bjmreu0znc2kRYsEml9/Ww= +golang.org/x/crypto v0.0.0-20200317142112-1b76d66859c6/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/image v0.0.0-20190823064033-3a9bac650e44/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/image v0.0.0-20190910094157-69e4b8554b2a h1:gHevYm0pO4QUbwy8Dmdr01R5r1BuKtfYqRqF0h/Cbh0= golang.org/x/image v0.0.0-20190910094157-69e4b8554b2a/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= @@ -696,6 +702,8 @@ google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9Ywl google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.6.1 h1:QzqyMA1tlu6CgqCDUtU9V+ZKhLFT2dkJuANu5QaxI3I= google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= +google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM= +google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.21.0/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= diff --git a/pkg/auth/authentication/auth.go b/pkg/auth/authentication/auth.go index 3962588ec7db..086fd8f7eb95 100644 --- a/pkg/auth/authentication/auth.go +++ b/pkg/auth/authentication/auth.go @@ -13,6 +13,7 @@ import ( "github.com/transcom/mymove/pkg/models/roles" + "github.com/alexedwards/scs/v2" "github.com/gobuffalo/pop" "github.com/gofrs/uuid" "github.com/markbates/goth" @@ -96,6 +97,7 @@ func UserAuthMiddleware(logger Logger) func(next http.Handler) http.Handler { http.Error(w, http.StatusText(401), http.StatusUnauthorized) return } + // DO NOT CHECK MILMOVE SESSION BECAUSE NEW SERVICE MEMBERS WON'T HAVE AN ID RIGHT AWAY // This must be the right type of user for the application if session.IsOfficeApp() && !session.IsOfficeUser() { @@ -258,24 +260,21 @@ func NewAuthContext(logger Logger, loginGovProvider LoginGovProvider, callbackPr // LogoutHandler handles logging the user out of login.gov type LogoutHandler struct { Context - clientAuthSecretKey string - noSessionTimeout bool - useSecureCookie bool + sessionManager *scs.SessionManager } // NewLogoutHandler creates a new LogoutHandler -func NewLogoutHandler(ac Context, clientAuthSecretKey string, noSessionTimeout bool, useSecureCookie bool) LogoutHandler { - handler := LogoutHandler{ - Context: ac, - clientAuthSecretKey: clientAuthSecretKey, - noSessionTimeout: noSessionTimeout, - useSecureCookie: useSecureCookie, +func NewLogoutHandler(ac Context, sessionManager *scs.SessionManager) LogoutHandler { + logoutHandler := LogoutHandler{ + Context: ac, + sessionManager: sessionManager, } - return handler + return logoutHandler } func (h LogoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { session := auth.SessionFromRequestContext(r) + if session != nil { redirectURL := h.landingURL(session) if session.IDToken != "" { @@ -288,11 +287,9 @@ func (h LogoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } else { logoutURL = h.loginGovProvider.LogoutURL(redirectURL, session.IDToken) } - // This operation will delete all cookies from the session - session.IDToken = "" - session.UserID = uuid.Nil - auth.WriteSessionCookie(w, session, h.clientAuthSecretKey, h.noSessionTimeout, h.logger, h.useSecureCookie) + h.sessionManager.Destroy(r.Context()) auth.DeleteCSRFCookies(w) + fmt.Fprint(w, logoutURL) } else { // Can't log out of login.gov without a token, redirect and let them re-auth @@ -309,6 +306,16 @@ const loginStateCookieTTLInSecs = 1800 // 30 mins to transit through login.gov. type RedirectHandler struct { Context UseSecureCookie bool + sessionManager *scs.SessionManager +} + +// NewRedirectHandler creates a new RedirectHandler +func NewRedirectHandler(ac Context, sessionManager *scs.SessionManager) RedirectHandler { + handler := RedirectHandler{ + Context: ac, + sessionManager: sessionManager, + } + return handler } func shaAsString(nonce string) string { @@ -361,28 +368,24 @@ func (h RedirectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // CallbackHandler processes a callback from login.gov type CallbackHandler struct { Context - db *pop.Connection - clientAuthSecretKey string - noSessionTimeout bool - useSecureCookie bool + db *pop.Connection + sessionManager *scs.SessionManager } // NewCallbackHandler creates a new CallbackHandler -func NewCallbackHandler(ac Context, db *pop.Connection, clientAuthSecretKey string, noSessionTimeout bool, useSecureCookie bool) CallbackHandler { +func NewCallbackHandler(ac Context, db *pop.Connection, sessionManager *scs.SessionManager) CallbackHandler { handler := CallbackHandler{ - Context: ac, - db: db, - clientAuthSecretKey: clientAuthSecretKey, - noSessionTimeout: noSessionTimeout, - useSecureCookie: useSecureCookie, + Context: ac, + db: db, + sessionManager: sessionManager, } return handler } // AuthorizationCallbackHandler handles the callback from the Login.gov authorization flow func (h CallbackHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - session := auth.SessionFromRequestContext(r) + if session == nil { h.logger.Error("Session missing") http.Error(w, http.StatusText(500), http.StatusInternalServerError) @@ -390,6 +393,7 @@ func (h CallbackHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } rawLandingURL := h.landingURL(session) + landingURL, err := url.Parse(rawLandingURL) if err != nil { h.logger.Error("Error parsing landing URL") @@ -431,13 +435,12 @@ func (h CallbackHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { zap.String("cookie", hash), zap.String("hash", shaAsString(returnedState))) - // This operation will delete all cookies from the session - session.IDToken = "" - session.UserID = uuid.Nil - auth.WriteSessionCookie(w, session, h.clientAuthSecretKey, h.noSessionTimeout, h.logger, h.useSecureCookie) // Delete lg_state cookie auth.DeleteCookie(w, StateCookieName(session)) + // This operation will delete all cookies from the session + h.sessionManager.Destroy(r.Context()) + // set error query landingQuery := landingURL.Query() landingQuery.Add("error", "SIGNIN_ERROR") @@ -473,6 +476,7 @@ func (h CallbackHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { session.IDToken = openIDSession.IDToken session.Email = openIDUser.Email + h.logger.Info("New Login", zap.String("OID_User", openIDUser.UserID), zap.String("OID_Email", openIDUser.Email), zap.String("Host", session.Hostname)) userIdentity, err := models.FetchUserIdentity(h.db, openIDUser.UserID) @@ -521,7 +525,13 @@ var authorizeUnknownUserNew = func(openIDUser goth.User, h CallbackHandler, sess return } } - auth.WriteSessionCookie(w, session, h.clientAuthSecretKey, h.noSessionTimeout, h.logger, h.useSecureCookie) + err = h.sessionManager.RenewToken(r.Context()) + if err != nil { + http.Error(w, err.Error(), 500) + return + } + h.sessionManager.Put(r.Context(), "session", session) + h.logger.Info("logged in", zap.Any("session", session)) http.Redirect(w, r, h.landingURL(session), http.StatusTemporaryRedirect) return } @@ -635,9 +645,15 @@ var authorizeKnownUserNew = func(userIdentity *models.UserIdentity, h CallbackHa session.LastName = userIdentity.LastName() session.Middle = userIdentity.Middle() + error := h.sessionManager.RenewToken(r.Context()) + if error != nil { + http.Error(w, error.Error(), http.StatusInternalServerError) + return + } + h.sessionManager.Put(r.Context(), "session", session) + h.logger.Info("logged in", zap.Any("session", session)) - auth.WriteSessionCookie(w, session, h.clientAuthSecretKey, h.noSessionTimeout, h.logger, h.useSecureCookie) http.Redirect(w, r, lURL, http.StatusTemporaryRedirect) } @@ -743,9 +759,17 @@ var authorizeKnownUser = func(userIdentity *models.UserIdentity, h CallbackHandl session.LastName = userIdentity.LastName() session.Middle = userIdentity.Middle() + // The session token must be renewed during sign in to prevent + // session fixation attacks + err := h.sessionManager.RenewToken(r.Context()) + if err != nil { + http.Error(w, err.Error(), 500) + return + } + h.sessionManager.Put(r.Context(), "session", session) + h.logger.Info("logged in", zap.Any("session", session)) - auth.WriteSessionCookie(w, session, h.clientAuthSecretKey, h.noSessionTimeout, h.logger, h.useSecureCookie) http.Redirect(w, r, lURL, http.StatusTemporaryRedirect) } @@ -814,9 +838,15 @@ var authorizeUnknownUser = func(openIDUser goth.User, h CallbackHandler, session return } + err = h.sessionManager.RenewToken(r.Context()) + if err != nil { + http.Error(w, err.Error(), 500) + return + } + h.sessionManager.Put(r.Context(), "session", session) + h.logger.Info("logged in", zap.Any("session", session)) - auth.WriteSessionCookie(w, session, h.clientAuthSecretKey, h.noSessionTimeout, h.logger, h.useSecureCookie) http.Redirect(w, r, h.landingURL(session), http.StatusTemporaryRedirect) } diff --git a/pkg/auth/authentication/auth_test.go b/pkg/auth/authentication/auth_test.go index 9e6cb26af7ec..59a5206bcade 100644 --- a/pkg/auth/authentication/auth_test.go +++ b/pkg/auth/authentication/auth_test.go @@ -1,6 +1,8 @@ package authentication import ( + "context" + "encoding/gob" "flag" "fmt" "log" @@ -9,7 +11,10 @@ import ( "net/url" "strconv" "testing" + "time" + "github.com/alexedwards/scs/v2" + "github.com/alexedwards/scs/v2/memstore" "github.com/markbates/goth" middleware "github.com/go-openapi/runtime/middleware" @@ -19,13 +24,11 @@ import ( "github.com/stretchr/testify/suite" "go.uber.org/zap" - "github.com/transcom/mymove/pkg/testdatagen" - "github.com/transcom/mymove/pkg/auth" "github.com/transcom/mymove/pkg/models" - "github.com/transcom/mymove/pkg/testingsuite" - "github.com/transcom/mymove/pkg/models/roles" + "github.com/transcom/mymove/pkg/testdatagen" + "github.com/transcom/mymove/pkg/testingsuite" ) const ( @@ -97,6 +100,7 @@ func (suite *AuthSuite) SetupTest() { if *useNewAuth { authorizeUnknownUser = authorizeUnknownUserNew } + gob.Register(auth.Session{}) } func TestAuthSuite(t *testing.T) { @@ -116,6 +120,23 @@ func fakeLoginGovProvider(logger Logger) LoginGovProvider { return NewLoginGovProvider("fakeHostname", "secret_key", logger) } +func setupScsSession(ctx context.Context, session *auth.Session) (context.Context, *scs.SessionManager) { + var sessionManager *scs.SessionManager + sessionManager = scs.New() + store := memstore.New() + sessionManager.Store = store + + values := make(map[string]interface{}) + values["session"] = session + expiry := time.Now().Add(30 * time.Minute).UTC() + b, _ := sessionManager.Codec.Encode(expiry, values) + + store.Commit("session_token", b, expiry) + scsContext, _ := sessionManager.Load(ctx, "session_token") + sessionManager.Commit(scsContext) + return scsContext, sessionManager +} + func (suite *AuthSuite) TestGenerateNonce() { t := suite.T() nonce := generateNonce() @@ -141,9 +162,11 @@ func (suite *AuthSuite) TestAuthorizationLogoutHandler() { } ctx := auth.SetSessionInRequestContext(req, &session) req = req.WithContext(ctx) + var sessionManager *scs.SessionManager + sessionManager = scs.New() authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) - handler := LogoutHandler{authContext, "fake key", false, false} + handler := sessionManager.LoadAndSave(LogoutHandler{authContext, sessionManager}) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req.WithContext(ctx)) @@ -188,7 +211,9 @@ func (suite *AuthSuite) TestRequireAuthMiddleware() { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handlerSession = auth.SessionFromRequestContext(r) }) - middleware := UserAuthMiddleware(suite.logger)(handler) + var sessionManager *scs.SessionManager + sessionManager = scs.New() + middleware := sessionManager.LoadAndSave(UserAuthMiddleware(suite.logger)(handler)) middleware.ServeHTTP(rr, req) @@ -201,7 +226,9 @@ func (suite *AuthSuite) TestIsLoggedInWhenNoUserLoggedIn() { req := httptest.NewRequest("GET", "/is_logged_in", nil) rr := httptest.NewRecorder() - handler := http.HandlerFunc(IsLoggedInMiddleware(suite.logger)) + var sessionManager *scs.SessionManager + sessionManager = scs.New() + handler := sessionManager.LoadAndSave(IsLoggedInMiddleware(suite.logger)) handler.ServeHTTP(rr, req) @@ -224,13 +251,15 @@ func (suite *AuthSuite) TestIsLoggedInWhenUserLoggedIn() { req := httptest.NewRequest("GET", "/is_logged_in", nil) + var sessionManager *scs.SessionManager + sessionManager = scs.New() // And: the context contains the auth values session := auth.Session{UserID: user.ID, IDToken: "fake Token"} ctx := auth.SetSessionInRequestContext(req, &session) req = req.WithContext(ctx) rr := httptest.NewRecorder() - handler := IsLoggedInMiddleware(suite.logger) + handler := sessionManager.LoadAndSave(IsLoggedInMiddleware(suite.logger)) handler.ServeHTTP(rr, req) @@ -250,7 +279,9 @@ func (suite *AuthSuite) TestRequireAuthMiddlewareUnauthorized() { req := httptest.NewRequest("GET", "/moves", nil) handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) - middleware := UserAuthMiddleware(suite.logger)(handler) + var sessionManager *scs.SessionManager + sessionManager = scs.New() + middleware := sessionManager.LoadAndSave(UserAuthMiddleware(suite.logger)(handler)) middleware.ServeHTTP(rr, req) @@ -329,12 +360,12 @@ func (suite *AuthSuite) TestAuthorizeDeactivateUser() { ctx := auth.SetSessionInRequestContext(req, &session) callbackPort := 1234 authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) + var sessionManager *scs.SessionManager + sessionManager = scs.New() h := CallbackHandler{ authContext, suite.DB(), - "fake key", - false, - false, + sessionManager, } rr := httptest.NewRecorder() authorizeKnownUser(&userIdentity, h, &session, rr, req.WithContext(ctx), "") @@ -362,15 +393,16 @@ func (suite *AuthSuite) TestAuthKnownSingleRoleOffice() { ctx := auth.SetSessionInRequestContext(req, &session) callbackPort := 1234 authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) + + scsContext, sessionManager := setupScsSession(ctx, &session) + h := CallbackHandler{ authContext, suite.DB(), - "fake key", - false, - false, + sessionManager, } rr := httptest.NewRecorder() - authorizeKnownUser(&userIdentity, h, &session, rr, req.WithContext(ctx), "") + authorizeKnownUser(&userIdentity, h, &session, rr, req.WithContext(scsContext), "") // Office app, so should only have office ID information suite.Equal(officeUserID, session.OfficeUserID) @@ -396,12 +428,12 @@ func (suite *AuthSuite) TestAuthorizeDeactivateOfficeUser() { ctx := auth.SetSessionInRequestContext(req, &session) callbackPort := 1234 authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) + var sessionManager *scs.SessionManager + sessionManager = scs.New() h := CallbackHandler{ authContext, suite.DB(), - "fake key", - false, - false, + sessionManager, } rr := httptest.NewRecorder() authorizeKnownUser(&userIdentity, h, &session, rr, req.WithContext(ctx), "") @@ -409,64 +441,6 @@ func (suite *AuthSuite) TestAuthorizeDeactivateOfficeUser() { suite.Equal(http.StatusForbidden, rr.Code, "authorizer did not recognize deactivated office user") } -func (suite *AuthSuite) TestRedirectLoginGovErrorMsg() { - officeUserID := uuid.Must(uuid.NewV4()) - userIdentity := models.UserIdentity{ - Active: true, - OfficeUserID: &officeUserID, - } - - req := httptest.NewRequest("GET", fmt.Sprintf("http://%s/login-gov/callback", OfficeTestHost), nil) - - fakeToken := "some_token" - fakeUUID, _ := uuid.FromString("39b28c92-0506-4bef-8b57-e39519f42dc2") - session := auth.Session{ - ApplicationName: auth.OfficeApp, - UserID: fakeUUID, - IDToken: fakeToken, - Hostname: OfficeTestHost, - } - // login.gov state cookie - cookieName := StateCookieName(&session) - cookie := http.Cookie{ - Name: cookieName, - Value: "some mis-matched hash value", - Path: "/", - Expires: auth.GetExpiryTimeFromMinutes(auth.SessionExpiryInMinutes), - } - req.AddCookie(&cookie) - - ctx := auth.SetSessionInRequestContext(req, &session) - callbackPort := 1234 - authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) - h := CallbackHandler{ - authContext, - suite.DB(), - "fake key", - false, - false, - } - rr := httptest.NewRecorder() - authorizeKnownUser(&userIdentity, h, &session, rr, req.WithContext(ctx), "") - - rr2 := httptest.NewRecorder() - h.ServeHTTP(rr2, req.WithContext(ctx)) - - // Office app, so should only have office ID information - suite.Equal(officeUserID, session.OfficeUserID) - - suite.Equal(2, len(rr2.Result().Cookies())) - // check for blank value for cookie login gov state value and the session cookie value - for _, cookie := range rr2.Result().Cookies() { - if cookie.Name == cookieName || cookie.Name == fmt.Sprintf("%s_%s", string(session.ApplicationName), auth.UserSessionCookieName) { - suite.Equal("blank", cookie.Value) - suite.Equal("/", cookie.Path) - } - } - - suite.Equal("http://office.example.com:1234/?error=SIGNIN_ERROR", rr2.Result().Header.Get("Location")) -} - func (suite *AuthSuite) TestAuthKnownSingleRoleAdmin() { adminUserID := uuid.Must(uuid.NewV4()) officeUserID := uuid.Must(uuid.NewV4()) @@ -493,15 +467,16 @@ func (suite *AuthSuite) TestAuthKnownSingleRoleAdmin() { ctx := auth.SetSessionInRequestContext(req, &session) callbackPort := 1234 authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) + + scsContext, sessionManager := setupScsSession(ctx, &session) + h := CallbackHandler{ authContext, suite.DB(), - "fake key", - false, - false, + sessionManager, } rr := httptest.NewRecorder() - authorizeKnownUser(&userIdentity, h, &session, rr, req.WithContext(ctx), "") + authorizeKnownUser(&userIdentity, h, &session, rr, req.WithContext(scsContext), "") // admin app, so should only have admin ID information suite.Equal(adminUserID, session.AdminUserID) @@ -531,12 +506,12 @@ func (suite *AuthSuite) TestAuthorizeDeactivateAdmin() { ctx := auth.SetSessionInRequestContext(req, &session) callbackPort := 1234 authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) + var sessionManager *scs.SessionManager + sessionManager = scs.New() h := CallbackHandler{ authContext, suite.DB(), - "fake key", - false, - false, + sessionManager, } rr := httptest.NewRecorder() authorizeKnownUser(&userIdentity, h, &session, rr, req.WithContext(ctx), "") @@ -568,12 +543,12 @@ func (suite *AuthSuite) TestAuthorizeUnknownUserOfficeDeactivated() { callbackPort := 1234 authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) + var sessionManager *scs.SessionManager + sessionManager = scs.New() h := CallbackHandler{ authContext, suite.DB(), - "fake key", - false, - false, + sessionManager, } rr := httptest.NewRecorder() @@ -604,12 +579,12 @@ func (suite *AuthSuite) TestAuthorizeUnknownUserOfficeNotFound() { callbackPort := 1234 authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) + var sessionManager *scs.SessionManager + sessionManager = scs.New() h := CallbackHandler{ authContext, suite.DB(), - "fake key", - false, - false, + sessionManager, } rr := httptest.NewRecorder() @@ -651,16 +626,17 @@ func (suite *AuthSuite) TestAuthorizeUnknownUserOfficeLogsIn() { callbackPort := 1234 authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) + + scsContext, sessionManager := setupScsSession(ctx, &session) + h := CallbackHandler{ authContext, suite.DB(), - "fake key", - false, - false, + sessionManager, } rr := httptest.NewRecorder() - authorizeUnknownUser(user, h, &session, rr, req.WithContext(ctx), "") + authorizeUnknownUser(user, h, &session, rr, req.WithContext(scsContext), "") // Office app, so should only have office ID information suite.Equal(officeUser.ID, session.OfficeUserID) @@ -687,12 +663,12 @@ func (suite *AuthSuite) TestAuthorizeUnknownUserAdminDeactivated() { callbackPort := 1234 authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) + var sessionManager *scs.SessionManager + sessionManager = scs.New() h := CallbackHandler{ authContext, suite.DB(), - "fake key", - false, - false, + sessionManager, } rr := httptest.NewRecorder() @@ -723,12 +699,12 @@ func (suite *AuthSuite) TestAuthorizeUnknownUserAdminNotFound() { callbackPort := 1234 authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) + var sessionManager *scs.SessionManager + sessionManager = scs.New() h := CallbackHandler{ authContext, suite.DB(), - "fake key", - false, - false, + sessionManager, } rr := httptest.NewRecorder() @@ -764,16 +740,17 @@ func (suite *AuthSuite) TestAuthorizeUnknownUserAdminLogsIn() { callbackPort := 1234 authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) + + scsContext, sessionManager := setupScsSession(ctx, &session) + h := CallbackHandler{ authContext, suite.DB(), - FakeRSAKey, - false, - false, + sessionManager, } rr := httptest.NewRecorder() - authorizeUnknownUser(user, h, &session, rr, req.WithContext(ctx), "") + authorizeUnknownUser(user, h, &session, rr, req.WithContext(scsContext), "") // Office app, so should only have office ID information suite.Equal(adminUser.ID, session.AdminUserID) diff --git a/pkg/auth/authentication/authgch_test.go b/pkg/auth/authentication/authgch_test.go index 248793efd17a..b77e30d96f39 100644 --- a/pkg/auth/authentication/authgch_test.go +++ b/pkg/auth/authentication/authgch_test.go @@ -5,6 +5,8 @@ import ( "net/http" "net/http/httptest" + "github.com/alexedwards/scs/v2" + "github.com/transcom/mymove/pkg/models/roles" "github.com/transcom/mymove/pkg/cli" @@ -49,17 +51,16 @@ func (suite *AuthSuite) TestCreateTOO() { req.AddCookie(&cookie) callbackPort := 1234 authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) + var sessionManager *scs.SessionManager + sessionManager = scs.New() h := CallbackHandler{ authContext, suite.DB(), - FakeRSAKey, - false, - false, + sessionManager, } rr := httptest.NewRecorder() h.SetFeatureFlag(FeatureFlag{Name: cli.FeatureFlagRoleBasedAuth, Active: true}) - h.ServeHTTP(rr, req) - + sessionManager.LoadAndSave(h).ServeHTTP(rr, req) suite.Equal(rr.Code, 307) } diff --git a/pkg/auth/authentication/authghc.go b/pkg/auth/authentication/authghc.go index 79a397005544..c0d076a7c144 100644 --- a/pkg/auth/authentication/authghc.go +++ b/pkg/auth/authentication/authghc.go @@ -137,7 +137,6 @@ func (uua UnknownUserAuthorizer) AuthorizeUnknownUser(openIDUser goth.User, sess return err } } - uua.logger.Info("logged in", zap.Any("session", session)) return nil } diff --git a/pkg/auth/authentication/devlocal.go b/pkg/auth/authentication/devlocal.go index 497504a9c936..8f101c60a162 100644 --- a/pkg/auth/authentication/devlocal.go +++ b/pkg/auth/authentication/devlocal.go @@ -8,6 +8,7 @@ import ( "strings" "time" + "github.com/alexedwards/scs/v2" "github.com/gobuffalo/pop" "github.com/gofrs/uuid" "github.com/gorilla/csrf" @@ -33,19 +34,15 @@ const ( type UserListHandler struct { db *pop.Connection Context - clientAuthSecretKey string - noSessionTimeout bool - useSecureCookie bool + sessionManager *scs.SessionManager } // NewUserListHandler returns a new UserListHandler -func NewUserListHandler(ac Context, db *pop.Connection, clientAuthSecretKey string, noSessionTimeout bool, useSecureCookie bool) UserListHandler { +func NewUserListHandler(ac Context, db *pop.Connection, sessionManager *scs.SessionManager) UserListHandler { handler := UserListHandler{ - Context: ac, - db: db, - clientAuthSecretKey: clientAuthSecretKey, - noSessionTimeout: noSessionTimeout, - useSecureCookie: useSecureCookie, + Context: ac, + db: db, + sessionManager: sessionManager, } return handler } @@ -57,9 +54,7 @@ func (h UserListHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // User is already authenticated, so clear out their current session and have // them try again. This the issue where a developer will get stuck with a stale // session and have to manually clear cookies to get back to the login page. - session.IDToken = "" - session.UserID = uuid.Nil - auth.WriteSessionCookie(w, session, h.clientAuthSecretKey, h.noSessionTimeout, h.logger, h.useSecureCookie) + h.sessionManager.Destroy(r.Context()) auth.DeleteCSRFCookies(w) http.Redirect(w, r, h.landingURL(session), http.StatusTemporaryRedirect) @@ -199,25 +194,21 @@ func (h UserListHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { type devlocalAuthHandler struct { Context - db *pop.Connection - appnames auth.ApplicationServername - clientAuthSecretKey string - noSessionTimeout bool - useSecureCookie bool + db *pop.Connection + appnames auth.ApplicationServername + sessionManager *scs.SessionManager } // AssignUserHandler logs a user in directly type AssignUserHandler devlocalAuthHandler // NewAssignUserHandler creates a new AssignUserHandler -func NewAssignUserHandler(ac Context, db *pop.Connection, appnames auth.ApplicationServername, clientAuthSecretKey string, noSessionTimeout bool, useSecureCookie bool) AssignUserHandler { +func NewAssignUserHandler(ac Context, db *pop.Connection, appnames auth.ApplicationServername, sessionManager *scs.SessionManager) AssignUserHandler { handler := AssignUserHandler{ - Context: ac, - db: db, - appnames: appnames, - clientAuthSecretKey: clientAuthSecretKey, - noSessionTimeout: noSessionTimeout, - useSecureCookie: useSecureCookie, + Context: ac, + db: db, + appnames: appnames, + sessionManager: sessionManager, } return handler } @@ -269,14 +260,12 @@ func (h AssignUserHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { type CreateUserHandler devlocalAuthHandler // NewCreateUserHandler creates a new CreateUserHandler -func NewCreateUserHandler(ac Context, db *pop.Connection, appnames auth.ApplicationServername, clientAuthSecretKey string, noSessionTimeout bool, useSecureCookie bool) CreateUserHandler { +func NewCreateUserHandler(ac Context, db *pop.Connection, appnames auth.ApplicationServername, sessionManager *scs.SessionManager) CreateUserHandler { handler := CreateUserHandler{ - Context: ac, - db: db, - appnames: appnames, - clientAuthSecretKey: clientAuthSecretKey, - noSessionTimeout: noSessionTimeout, - useSecureCookie: useSecureCookie, + Context: ac, + db: db, + appnames: appnames, + sessionManager: sessionManager, } return handler } @@ -302,14 +291,12 @@ func (h CreateUserHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { type CreateAndLoginUserHandler devlocalAuthHandler // NewCreateAndLoginUserHandler creates a new CreateAndLoginUserHandler -func NewCreateAndLoginUserHandler(ac Context, db *pop.Connection, appnames auth.ApplicationServername, clientAuthSecretKey string, noSessionTimeout bool, useSecureCookie bool) CreateAndLoginUserHandler { +func NewCreateAndLoginUserHandler(ac Context, db *pop.Connection, appnames auth.ApplicationServername, sessionManager *scs.SessionManager) CreateAndLoginUserHandler { handler := CreateAndLoginUserHandler{ - Context: ac, - db: db, - appnames: appnames, - clientAuthSecretKey: clientAuthSecretKey, - noSessionTimeout: noSessionTimeout, - useSecureCookie: useSecureCookie, + Context: ac, + db: db, + appnames: appnames, + sessionManager: sessionManager, } return handler } @@ -535,10 +522,10 @@ func createSession(h devlocalAuthHandler, user *models.User, userType string, w session.LastName = userIdentity.LastName() session.Middle = userIdentity.Middle() + h.sessionManager.Cookie.Name = auth.SessionCookieName(session) + h.sessionManager.Put(r.Context(), "session", session) // Writing out the session cookie logs in the user h.logger.Info("logged in", zap.Any("session", session)) - auth.WriteSessionCookie(w, session, h.clientAuthSecretKey, h.noSessionTimeout, h.logger, h.useSecureCookie) - return session, nil } diff --git a/pkg/auth/authentication/devlocal_test.go b/pkg/auth/authentication/devlocal_test.go index e522c51984ec..865dbfd05f88 100644 --- a/pkg/auth/authentication/devlocal_test.go +++ b/pkg/auth/authentication/devlocal_test.go @@ -8,6 +8,7 @@ import ( "net/url" "strings" + "github.com/alexedwards/scs/v2" "github.com/pkg/errors" "github.com/transcom/mymove/pkg/models" @@ -37,11 +38,13 @@ func (suite *AuthSuite) TestCreateUserHandlerMilMove() { req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") req.ParseForm() + var sessionManager *scs.SessionManager + sessionManager = scs.New() authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) - handler := NewCreateUserHandler(authContext, suite.DB(), appnames, FakeRSAKey, false, false) + handler := NewCreateUserHandler(authContext, suite.DB(), appnames, sessionManager) rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) + sessionManager.LoadAndSave(handler).ServeHTTP(rr, req) suite.Equal(http.StatusOK, rr.Code, "handler returned wrong status code") if status := rr.Code; status != http.StatusOK { @@ -78,11 +81,13 @@ func (suite *AuthSuite) TestCreateUserHandlerOffice() { req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") req.ParseForm() + var sessionManager *scs.SessionManager + sessionManager = scs.New() authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) - handler := NewCreateUserHandler(authContext, suite.DB(), appnames, FakeRSAKey, false, false) + handler := NewCreateUserHandler(authContext, suite.DB(), appnames, sessionManager) rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) + sessionManager.LoadAndSave(handler).ServeHTTP(rr, req) suite.Equal(http.StatusOK, rr.Code, "handler returned wrong status code") if status := rr.Code; status != http.StatusOK { @@ -123,11 +128,13 @@ func (suite *AuthSuite) TestCreateUserHandlerDPS() { req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") req.ParseForm() + var sessionManager *scs.SessionManager + sessionManager = scs.New() authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) - handler := NewCreateUserHandler(authContext, suite.DB(), appnames, FakeRSAKey, false, false) + handler := NewCreateUserHandler(authContext, suite.DB(), appnames, sessionManager) rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) + sessionManager.LoadAndSave(handler).ServeHTTP(rr, req) suite.Equal(http.StatusOK, rr.Code, "handler returned wrong status code") if status := rr.Code; status != http.StatusOK { @@ -164,11 +171,13 @@ func (suite *AuthSuite) TestCreateUserHandlerAdmin() { req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") req.ParseForm() + var sessionManager *scs.SessionManager + sessionManager = scs.New() authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) - handler := NewCreateUserHandler(authContext, suite.DB(), appnames, FakeRSAKey, false, false) + handler := NewCreateUserHandler(authContext, suite.DB(), appnames, sessionManager) rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) + sessionManager.LoadAndSave(handler).ServeHTTP(rr, req) suite.Equal(http.StatusOK, rr.Code, "handler returned wrong status code") if status := rr.Code; status != http.StatusOK { @@ -210,11 +219,13 @@ func (suite *AuthSuite) TestCreateAndLoginUserHandlerFromMilMoveToMilMove() { req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") req.ParseForm() + var sessionManager *scs.SessionManager + sessionManager = scs.New() authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) - handler := NewCreateAndLoginUserHandler(authContext, suite.DB(), appnames, FakeRSAKey, false, false) + handler := NewCreateAndLoginUserHandler(authContext, suite.DB(), appnames, sessionManager) rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) + sessionManager.LoadAndSave(handler).ServeHTTP(rr, req) suite.Equal(http.StatusSeeOther, rr.Code, "handler returned wrong status code") if status := rr.Code; status != http.StatusSeeOther { @@ -242,11 +253,13 @@ func (suite *AuthSuite) TestCreateAndLoginUserHandlerFromMilMoveToOffice() { req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") req.ParseForm() + var sessionManager *scs.SessionManager + sessionManager = scs.New() authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) - handler := NewCreateAndLoginUserHandler(authContext, suite.DB(), appnames, FakeRSAKey, false, false) + handler := NewCreateAndLoginUserHandler(authContext, suite.DB(), appnames, sessionManager) rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) + sessionManager.LoadAndSave(handler).ServeHTTP(rr, req) suite.Equal(http.StatusSeeOther, rr.Code, "handler returned wrong status code") if status := rr.Code; status != http.StatusSeeOther { @@ -274,11 +287,13 @@ func (suite *AuthSuite) TestCreateAndLoginUserHandlerFromMilMoveToAdmin() { req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") req.ParseForm() + var sessionManager *scs.SessionManager + sessionManager = scs.New() authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) - handler := NewCreateAndLoginUserHandler(authContext, suite.DB(), appnames, FakeRSAKey, false, false) + handler := NewCreateAndLoginUserHandler(authContext, suite.DB(), appnames, sessionManager) rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) + sessionManager.LoadAndSave(handler).ServeHTTP(rr, req) suite.Equal(http.StatusSeeOther, rr.Code, "handler returned wrong status code") if status := rr.Code; status != http.StatusSeeOther { @@ -306,11 +321,13 @@ func (suite *AuthSuite) TestCreateAndLoginUserHandlerFromOfficeToMilMove() { req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") req.ParseForm() + var sessionManager *scs.SessionManager + sessionManager = scs.New() authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) - handler := NewCreateAndLoginUserHandler(authContext, suite.DB(), appnames, FakeRSAKey, false, false) + handler := NewCreateAndLoginUserHandler(authContext, suite.DB(), appnames, sessionManager) rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) + sessionManager.LoadAndSave(handler).ServeHTTP(rr, req) suite.Equal(http.StatusSeeOther, rr.Code, "handler returned wrong status code") if status := rr.Code; status != http.StatusSeeOther { @@ -338,11 +355,13 @@ func (suite *AuthSuite) TestCreateAndLoginUserHandlerFromOfficeToAdmin() { req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") req.ParseForm() + var sessionManager *scs.SessionManager + sessionManager = scs.New() authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) - handler := NewCreateAndLoginUserHandler(authContext, suite.DB(), appnames, FakeRSAKey, false, false) + handler := NewCreateAndLoginUserHandler(authContext, suite.DB(), appnames, sessionManager) rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) + sessionManager.LoadAndSave(handler).ServeHTTP(rr, req) suite.Equal(http.StatusSeeOther, rr.Code, "handler returned wrong status code") if status := rr.Code; status != http.StatusSeeOther { @@ -370,11 +389,13 @@ func (suite *AuthSuite) TestCreateAndLoginUserHandlerFromAdminToMilMove() { req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") req.ParseForm() + var sessionManager *scs.SessionManager + sessionManager = scs.New() authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) - handler := NewCreateAndLoginUserHandler(authContext, suite.DB(), appnames, FakeRSAKey, false, false) + handler := NewCreateAndLoginUserHandler(authContext, suite.DB(), appnames, sessionManager) rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) + sessionManager.LoadAndSave(handler).ServeHTTP(rr, req) suite.Equal(http.StatusSeeOther, rr.Code, "handler returned wrong status code") if status := rr.Code; status != http.StatusSeeOther { @@ -402,11 +423,13 @@ func (suite *AuthSuite) TestCreateAndLoginUserHandlerFromAdminToOffice() { req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") req.ParseForm() + var sessionManager *scs.SessionManager + sessionManager = scs.New() authContext := NewAuthContext(suite.logger, fakeLoginGovProvider(suite.logger), "http", callbackPort) - handler := NewCreateAndLoginUserHandler(authContext, suite.DB(), appnames, FakeRSAKey, false, false) + handler := NewCreateAndLoginUserHandler(authContext, suite.DB(), appnames, sessionManager) rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) + sessionManager.LoadAndSave(handler).ServeHTTP(rr, req) suite.Equal(http.StatusSeeOther, rr.Code, "handler returned wrong status code") if status := rr.Code; status != http.StatusSeeOther { diff --git a/pkg/auth/cookie.go b/pkg/auth/cookie.go index e2838650af95..29fa36b19d32 100644 --- a/pkg/auth/cookie.go +++ b/pkg/auth/cookie.go @@ -6,8 +6,7 @@ import ( "strings" "time" - jwt "github.com/dgrijalva/jwt-go" - "github.com/gofrs/uuid" + "github.com/alexedwards/scs/v2" "github.com/gorilla/csrf" "github.com/pkg/errors" "go.uber.org/zap" @@ -48,12 +47,6 @@ const MaskedGorillaCSRFToken = "masked_gorilla_csrf" // SessionExpiryInMinutes is the number of minutes before a fallow session is harvested const SessionExpiryInMinutes = 15 -const sessionExpiryInSeconds = 15 * 60 - -// A representable date far in the future. The trouble with something like https://stackoverflow.com/a/32620397 -// is that it produces a date which may not marshall well into JSON which makes logging problematic -var likeForever = time.Date(9999, 1, 1, 12, 0, 0, 0, time.UTC) -var likeForeverInSeconds = 99999999 // GetExpiryTimeFromMinutes returns 'min' minutes from now func GetExpiryTimeFromMinutes(min int64) time.Time { @@ -81,59 +74,9 @@ func DeleteCookie(w http.ResponseWriter, name string) { http.SetCookie(w, &c) } -// SessionClaims wraps StandardClaims with some Session info -type SessionClaims struct { - jwt.StandardClaims - SessionValue Session -} - -func signTokenStringWithUserInfo(expiry time.Time, session *Session, secret string) (string, error) { - claims := SessionClaims{ - StandardClaims: jwt.StandardClaims{ExpiresAt: expiry.Unix()}, - SessionValue: *session, - } - - token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) - rsaKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(secret)) - if err != nil { - err = errors.Wrap(err, "Parsing RSA key from PEM") - return "", err - } - - ss, err := token.SignedString(rsaKey) - if err != nil { - err = errors.Wrap(err, "Signing string with token") - return "", err - } - return ss, err -} - -func sessionClaimsFromRequest(logger Logger, secret string, appName Application, r *http.Request) (claims *SessionClaims, ok bool) { - // Name the cookie with the app name - cookieName := fmt.Sprintf("%s_%s", string(appName), UserSessionCookieName) - cookie, err := r.Cookie(cookieName) - if err != nil { - // No cookie set on client - return - } - - token, err := jwt.ParseWithClaims(cookie.Value, &SessionClaims{}, func(token *jwt.Token) (interface{}, error) { - rsaKey, parseRSAPrivateKeyFromPEMErr := jwt.ParseRSAPrivateKeyFromPEM([]byte(secret)) - return &rsaKey.PublicKey, parseRSAPrivateKeyFromPEMErr - }) - - if err != nil || token == nil || !token.Valid { - logger.Error("Failed token validation", zap.Error(err)) - return - } - - // The token actually just stores a Claims interface, so we need to explicitly cast back to UserClaims - claims, ok = token.Claims.(*SessionClaims) - if !ok { - logger.Error("Failed getting claims from token") - return - } - return claims, ok +// SessionCookieName returns the session cookie name +func SessionCookieName(session *Session) string { + return fmt.Sprintf("%s_%s", string(session.ApplicationName), UserSessionCookieName) } // WriteMaskedCSRFCookie update the masked_gorilla_csrf cookie value @@ -171,48 +114,6 @@ func MaskedCSRFMiddleware(logger Logger, useSecureCookie bool) func(next http.Ha } } -// WriteSessionCookie update the cookie for the session -func WriteSessionCookie(w http.ResponseWriter, session *Session, secret string, noSessionTimeout bool, logger Logger, useSecureCookie bool) { - // Delete the cookie - cookieName := fmt.Sprintf("%s_%s", string(session.ApplicationName), UserSessionCookieName) - cookie := http.Cookie{ - Name: cookieName, - Value: "blank", - Path: "/", - Expires: time.Unix(0, 0), - MaxAge: -1, - HttpOnly: true, - SameSite: http.SameSiteLaxMode, // Using 'lax' mode now since 'strict' breaks the use of the login.gov redirect - Secure: useSecureCookie, - } - - // unless we have a valid session - if session.IDToken != "" && session.UserID != uuid.Nil { - expiry := GetExpiryTimeFromMinutes(SessionExpiryInMinutes) - maxAge := sessionExpiryInSeconds - // Never expire token if in development - if noSessionTimeout { - expiry = likeForever - maxAge = likeForeverInSeconds - } - - ss, err := signTokenStringWithUserInfo(expiry, session, secret) - if err != nil { - logger.Error("Generating signed token string", zap.Error(err)) - } else { - logger.Info("Cookie", zap.Int("Size", len(ss))) - cookie.Value = ss - cookie.Expires = expiry - cookie.MaxAge = maxAge - } - } - // http.SetCookie calls Header().Add() instead of .Set(), which can result in duplicate cookies - // It's ok to use this here because we want to delete and rewrite `Set-Cookie` on login or if the - // session token is lost. However, we would normally use http.SetCookie for any other cookie operations - // so as not to delete the session token. - w.Header().Set("Set-Cookie", cookie.String()) -} - // ApplicationName returns the application name given the hostname func ApplicationName(hostname string, appnames ApplicationServername) (Application, error) { var appName Application @@ -233,7 +134,7 @@ func ApplicationName(hostname string, appnames ApplicationServername) (Applicati } // SessionCookieMiddleware handle serializing and de-serializing the session between the user_session cookie and the request context -func SessionCookieMiddleware(serverLogger Logger, secret string, noSessionTimeout bool, appnames ApplicationServername, useSecureCookie bool) func(next http.Handler) http.Handler { +func SessionCookieMiddleware(serverLogger Logger, appnames ApplicationServername, sessionManager *scs.SessionManager) func(next http.Handler) http.Handler { serverLogger.Info("Creating session", zap.String("milServername", appnames.MilServername), zap.String("officeServername", appnames.OfficeServername), @@ -261,9 +162,10 @@ func SessionCookieMiddleware(serverLogger Logger, secret string, noSessionTimeou http.Error(w, http.StatusText(400), http.StatusBadRequest) return } - claims, ok := sessionClaimsFromRequest(logger, secret, appName, r) - if ok { - session = claims.SessionValue + + existingSession := sessionManager.Get(r.Context(), "session") + if existingSession != nil { + session = existingSession.(Session) } // Set more information on the session @@ -271,7 +173,8 @@ func SessionCookieMiddleware(serverLogger Logger, secret string, noSessionTimeou session.Hostname = strings.ToLower(hostname) // And update the cookie. May get over-ridden later - WriteSessionCookie(w, &session, secret, noSessionTimeout, logger, useSecureCookie) + sessionManager.Cookie.Name = SessionCookieName(&session) + sessionManager.Put(r.Context(), "session", session) // And put the session info into the request context next.ServeHTTP(w, r.WithContext(SetSessionInContext(ctx, &session))) diff --git a/pkg/auth/cookie_test.go b/pkg/auth/cookie_test.go index 43cbeec086b8..0fa5e0f79794 100644 --- a/pkg/auth/cookie_test.go +++ b/pkg/auth/cookie_test.go @@ -1,35 +1,18 @@ package auth import ( - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "encoding/pem" + "encoding/gob" "fmt" "net/http" "net/http/httptest" "strings" "time" - "github.com/gofrs/uuid" - "github.com/pkg/errors" + "github.com/alexedwards/scs/v2" ) -func createRandomRSAPEM() (s string, err error) { - priv, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - err = errors.Wrap(err, "failed to generate key") - return - } - - asn1 := x509.MarshalPKCS1PrivateKey(priv) - privBytes := pem.EncodeToMemory(&pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: asn1, - }) - s = string(privBytes[:]) - - return +func (suite *authSuite) SetupTest() { + gob.Register(Session{}) } func getHandlerParamsWithToken(ss string, expiry time.Time) (*httptest.ResponseRecorder, *http.Request) { @@ -52,24 +35,21 @@ func getHandlerParamsWithToken(ss string, expiry time.Time) (*httptest.ResponseR } func (suite *authSuite) TestSessionCookieMiddlewareWithBadToken() { - t := suite.T() fakeToken := "some_token" - pem, err := createRandomRSAPEM() - if err != nil { - t.Error("error creating RSA key", err) - } + var sessionManager *scs.SessionManager + sessionManager = scs.New() var resultingSession *Session handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { resultingSession = SessionFromRequestContext(r) }) appnames := ApplicationTestServername() - middleware := SessionCookieMiddleware(suite.logger, pem, false, appnames, false)(handler) + middleware := SessionCookieMiddleware(suite.logger, appnames, sessionManager)(handler) expiry := GetExpiryTimeFromMinutes(SessionExpiryInMinutes) rr, req := getHandlerParamsWithToken(fakeToken, expiry) - middleware.ServeHTTP(rr, req) + sessionManager.LoadAndSave(middleware).ServeHTTP(rr, req) // We should be not be redirected since we're not enforcing auth suite.Equal(http.StatusOK, rr.Code, "handler returned wrong status code") @@ -79,142 +59,6 @@ func (suite *authSuite) TestSessionCookieMiddlewareWithBadToken() { suite.Equal("", resultingSession.IDToken, "Expected empty IDToken from bad cookie") } -func (suite *authSuite) TestSessionCookieMiddlewareWithValidToken() { - t := suite.T() - email := "some_email@domain.com" - idToken := "fake_id_token" - fakeUUID, _ := uuid.FromString("39b28c92-0506-4bef-8b57-e39519f42dc2") - - pem, err := createRandomRSAPEM() - if err != nil { - t.Fatal(err) - } - - expiry := GetExpiryTimeFromMinutes(SessionExpiryInMinutes) - incomingSession := Session{ - UserID: fakeUUID, - Email: email, - IDToken: idToken, - } - ss, err := signTokenStringWithUserInfo(expiry, &incomingSession, pem) - if err != nil { - t.Fatal(err) - } - rr, req := getHandlerParamsWithToken(ss, expiry) - - var resultingSession *Session - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - resultingSession = SessionFromRequestContext(r) - }) - appnames := ApplicationTestServername() - middleware := SessionCookieMiddleware(suite.logger, pem, false, appnames, false)(handler) - - middleware.ServeHTTP(rr, req) - - // We should get a 200 OK - suite.Equal(http.StatusOK, rr.Code, "handler returned wrong status code") - - // And there should be an ID token in the request context - suite.NotNil(resultingSession) - suite.Equal(idToken, resultingSession.IDToken, "handler returned wrong id_token") - - // And the cookie should be renewed - setCookies := rr.HeaderMap["Set-Cookie"] - suite.Equal(1, len(setCookies), "expected cookie to be set") -} - -func (suite *authSuite) TestSessionCookieMiddlewareWithExpiredToken() { - t := suite.T() - email := "some_email@domain.com" - idToken := "fake_id_token" - fakeUUID, _ := uuid.FromString("39b28c92-0506-4bef-8b57-e39519f42dc2") - - pem, err := createRandomRSAPEM() - if err != nil { - t.Fatal(err) - } - - expiry := GetExpiryTimeFromMinutes(-1) - incomingSession := Session{ - UserID: fakeUUID, - Email: email, - IDToken: idToken, - } - ss, err := signTokenStringWithUserInfo(expiry, &incomingSession, pem) - if err != nil { - t.Fatal(err) - } - - var resultingSession *Session - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - resultingSession = SessionFromRequestContext(r) - }) - appnames := ApplicationTestServername() - middleware := SessionCookieMiddleware(suite.logger, pem, false, appnames, false)(handler) - - rr, req := getHandlerParamsWithToken(ss, expiry) - - middleware.ServeHTTP(rr, req) - - // We should be not be redirected since we're not enforcing auth - suite.Equal(http.StatusOK, rr.Code, "handler returned wrong status code") - - // And there should be no token passed through - // And there should be no token passed through - suite.NotNil(resultingSession) - suite.Equal("", resultingSession.IDToken, "Expected empty IDToken from expired") - suite.Equal(uuid.Nil, resultingSession.UserID, "Expected no UUID from expired cookie") - - // And the cookie should be set - setCookies := rr.HeaderMap["Set-Cookie"] - suite.Equal(1, len(setCookies), "expected cookie to be set") -} - -func (suite *authSuite) TestSessionCookiePR161162731() { - t := suite.T() - email := "some_email@domain.com" - idToken := "fake_id_token" - fakeUUID, _ := uuid.FromString("39b28c92-0506-4bef-8b57-e39519f42dc2") - - pem, err := createRandomRSAPEM() - if err != nil { - t.Fatal(err) - } - - expiry := GetExpiryTimeFromMinutes(SessionExpiryInMinutes) - incomingSession := Session{ - UserID: fakeUUID, - Email: email, - IDToken: idToken, - } - ss, err := signTokenStringWithUserInfo(expiry, &incomingSession, pem) - if err != nil { - t.Fatal(err) - } - rr, req := getHandlerParamsWithToken(ss, expiry) - - var resultingSession *Session - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - resultingSession = SessionFromRequestContext(r) - WriteSessionCookie(w, resultingSession, "freddy", false, suite.logger, false) - }) - appnames := ApplicationTestServername() - middleware := SessionCookieMiddleware(suite.logger, pem, false, appnames, false)(handler) - - middleware.ServeHTTP(rr, req) - - // We should get a 200 OK - suite.Equal(http.StatusOK, rr.Code, "handler returned wrong status code") - - // And there should be an ID token in the request context - suite.NotNil(resultingSession) - suite.Equal(idToken, resultingSession.IDToken, "handler returned wrong id_token") - - // And the cookie should be renewed - setCookies := rr.HeaderMap["Set-Cookie"] - suite.Equal(1, len(setCookies), "expected cookie to be set") -} - func (suite *authSuite) TestMaskedCSRFMiddleware() { rr := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/", nil) @@ -261,7 +105,9 @@ func (suite *authSuite) TestMaskedCSRFMiddlewareCreatesNewToken() { func (suite *authSuite) TestMiddlewareConstructor() { appnames := ApplicationTestServername() - adm := SessionCookieMiddleware(suite.logger, "secret", false, appnames, false) + var sessionManager *scs.SessionManager + sessionManager = scs.New() + adm := SessionCookieMiddleware(suite.logger, appnames, sessionManager) suite.NotNil(adm) } @@ -276,16 +122,18 @@ func (suite *authSuite) TestMiddlewareMilApp() { suite.False(session.IsAdminApp(), "first should not be admin app") suite.Equal(appnames.MilServername, session.Hostname) }) - milMoveMiddleware := SessionCookieMiddleware(suite.logger, "secret", false, appnames, false)(milMoveTestHandler) + var sessionManager *scs.SessionManager + sessionManager = scs.New() + milMoveMiddleware := SessionCookieMiddleware(suite.logger, appnames, sessionManager)(milMoveTestHandler) req := httptest.NewRequest("GET", fmt.Sprintf("http://%s/some_url", appnames.MilServername), nil) - milMoveMiddleware.ServeHTTP(rr, req) + sessionManager.LoadAndSave(milMoveMiddleware).ServeHTTP(rr, req) req, _ = http.NewRequest("GET", fmt.Sprintf("http://%s:8080/some_url", appnames.MilServername), nil) - milMoveMiddleware.ServeHTTP(rr, req) + sessionManager.LoadAndSave(milMoveMiddleware).ServeHTTP(rr, req) req, _ = http.NewRequest("GET", fmt.Sprintf("http://%s:8080/some_url", strings.ToUpper(appnames.MilServername)), nil) - milMoveMiddleware.ServeHTTP(rr, req) + sessionManager.LoadAndSave(milMoveMiddleware).ServeHTTP(rr, req) } func (suite *authSuite) TestMiddlwareOfficeApp() { @@ -299,16 +147,18 @@ func (suite *authSuite) TestMiddlwareOfficeApp() { suite.False(session.IsAdminApp(), "should not be admin app") suite.Equal(appnames.OfficeServername, session.Hostname) }) - officeMiddleware := SessionCookieMiddleware(suite.logger, "secret", false, appnames, false)(officeTestHandler) + var sessionManager *scs.SessionManager + sessionManager = scs.New() + officeMiddleware := SessionCookieMiddleware(suite.logger, appnames, sessionManager)(officeTestHandler) req := httptest.NewRequest("GET", fmt.Sprintf("http://%s/some_url", appnames.OfficeServername), nil) - officeMiddleware.ServeHTTP(rr, req) + sessionManager.LoadAndSave(officeMiddleware).ServeHTTP(rr, req) req, _ = http.NewRequest("GET", fmt.Sprintf("http://%s:8080/some_url", appnames.OfficeServername), nil) - officeMiddleware.ServeHTTP(rr, req) + sessionManager.LoadAndSave(officeMiddleware).ServeHTTP(rr, req) req, _ = http.NewRequest("GET", fmt.Sprintf("http://%s:8080/some_url", strings.ToUpper(appnames.OfficeServername)), nil) - officeMiddleware.ServeHTTP(rr, req) + sessionManager.LoadAndSave(officeMiddleware).ServeHTTP(rr, req) } func (suite *authSuite) TestMiddlwareAdminApp() { @@ -322,16 +172,18 @@ func (suite *authSuite) TestMiddlwareAdminApp() { suite.True(session.IsAdminApp(), "should be admin app") suite.Equal(AdminTestHost, session.Hostname) }) - adminMiddleware := SessionCookieMiddleware(suite.logger, "secret", false, appnames, false)(adminTestHandler) + var sessionManager *scs.SessionManager + sessionManager = scs.New() + adminMiddleware := SessionCookieMiddleware(suite.logger, appnames, sessionManager)(adminTestHandler) req := httptest.NewRequest("GET", fmt.Sprintf("http://%s/some_url", AdminTestHost), nil) - adminMiddleware.ServeHTTP(rr, req) + sessionManager.LoadAndSave(adminMiddleware).ServeHTTP(rr, req) req, _ = http.NewRequest("GET", fmt.Sprintf("http://%s:8080/some_url", AdminTestHost), nil) - adminMiddleware.ServeHTTP(rr, req) + sessionManager.LoadAndSave(adminMiddleware).ServeHTTP(rr, req) req, _ = http.NewRequest("GET", fmt.Sprintf("http://%s:8080/some_url", strings.ToUpper(AdminTestHost)), nil) - adminMiddleware.ServeHTTP(rr, req) + sessionManager.LoadAndSave(adminMiddleware).ServeHTTP(rr, req) } func (suite *authSuite) TestMiddlewareBadApp() { @@ -341,9 +193,11 @@ func (suite *authSuite) TestMiddlewareBadApp() { suite.Fail("Should not be called") }) appnames := ApplicationTestServername() - noAppMiddleware := SessionCookieMiddleware(suite.logger, "secret", false, appnames, false)(noAppTestHandler) + var sessionManager *scs.SessionManager + sessionManager = scs.New() + noAppMiddleware := SessionCookieMiddleware(suite.logger, appnames, sessionManager)(noAppTestHandler) req := httptest.NewRequest("GET", "http://totally.bogus.hostname/some_url", nil) - noAppMiddleware.ServeHTTP(rr, req) + sessionManager.LoadAndSave(noAppMiddleware).ServeHTTP(rr, req) suite.Equal(http.StatusBadRequest, rr.Code, "Should get an error ") } diff --git a/pkg/handlers/internalapi/api.go b/pkg/handlers/internalapi/api.go index 5db0018ef18a..f37319b0e1a8 100644 --- a/pkg/handlers/internalapi/api.go +++ b/pkg/handlers/internalapi/api.go @@ -8,6 +8,7 @@ import ( movedocument "github.com/transcom/mymove/pkg/services/move_documents" postalcodeservice "github.com/transcom/mymove/pkg/services/postal_codes" + "github.com/alexedwards/scs/v2" "github.com/go-openapi/loads" "github.com/go-openapi/runtime" "github.com/pkg/errors" @@ -18,7 +19,7 @@ import ( ) // NewInternalAPI returns the internal API -func NewInternalAPI(context handlers.HandlerContext) *internalops.MymoveAPI { +func NewInternalAPI(context handlers.HandlerContext, sessionManager *scs.SessionManager) *internalops.MymoveAPI { internalSpec, err := loads.Analyzed(internalapi.SwaggerJSON, "") if err != nil { @@ -69,7 +70,7 @@ func NewInternalAPI(context handlers.HandlerContext) *internalops.MymoveAPI { internalAPI.MoveDocsCreateWeightTicketDocumentHandler = CreateWeightTicketSetDocumentHandler{context} - internalAPI.ServiceMembersCreateServiceMemberHandler = CreateServiceMemberHandler{context} + internalAPI.ServiceMembersCreateServiceMemberHandler = CreateServiceMemberHandler{context, sessionManager} internalAPI.ServiceMembersPatchServiceMemberHandler = PatchServiceMemberHandler{context} internalAPI.ServiceMembersShowServiceMemberHandler = ShowServiceMemberHandler{context} internalAPI.ServiceMembersShowServiceMemberOrdersHandler = ShowServiceMemberOrdersHandler{context} diff --git a/pkg/handlers/internalapi/service_members.go b/pkg/handlers/internalapi/service_members.go index 5add7efbac96..2e43b60b9ffa 100644 --- a/pkg/handlers/internalapi/service_members.go +++ b/pkg/handlers/internalapi/service_members.go @@ -3,6 +3,7 @@ package internalapi import ( "context" + "github.com/alexedwards/scs/v2" "github.com/go-openapi/runtime/middleware" "github.com/gobuffalo/validate" "github.com/gofrs/uuid" @@ -69,6 +70,7 @@ func payloadForServiceMemberModel(storer storage.FileStorer, serviceMember model // CreateServiceMemberHandler creates a new service member via POST /serviceMember type CreateServiceMemberHandler struct { handlers.HandlerContext + sessionManager *scs.SessionManager } // Handle ... creates a new ServiceMember from a request payload @@ -148,10 +150,12 @@ func (h CreateServiceMemberHandler) Handle(params servicememberop.CreateServiceM if newServiceMember.LastName != nil { session.LastName = *(newServiceMember.LastName) } + // Update session cookie here instead of in responders? + // h.sessionManager.Put(ctx, "session", session) // And return serviceMemberPayload := payloadForServiceMemberModel(h.FileStorer(), newServiceMember, h.HandlerContext.GetFeatureFlag(cli.FeatureFlagAccessCode)) responder := servicememberop.NewCreateServiceMemberCreated().WithPayload(serviceMemberPayload) - return handlers.NewCookieUpdateResponder(params.HTTPRequest, h.CookieSecret(), h.NoSessionTimeout(), logger, responder, h.UseSecureCookie()) + return handlers.NewCookieUpdateResponder(params.HTTPRequest, logger, responder, h.sessionManager, session) } // ShowServiceMemberHandler returns a serviceMember for a user and service member ID diff --git a/pkg/handlers/internalapi/service_members_test.go b/pkg/handlers/internalapi/service_members_test.go index 481c3ec6fc51..d3bcb9f55b12 100644 --- a/pkg/handlers/internalapi/service_members_test.go +++ b/pkg/handlers/internalapi/service_members_test.go @@ -6,6 +6,7 @@ import ( "net/http" "net/http/httptest" + "github.com/alexedwards/scs/v2" "github.com/go-openapi/strfmt" "github.com/go-openapi/swag" "github.com/gofrs/uuid" @@ -85,8 +86,9 @@ func (suite *HandlerSuite) TestSubmitServiceMemberHandlerNoValues() { CreateServiceMemberPayload: &newServiceMemberPayload, HTTPRequest: req, } - - handler := CreateServiceMemberHandler{handlers.NewHandlerContext(suite.DB(), suite.TestLogger())} + var sessionManager *scs.SessionManager + sessionManager = scs.New() + handler := CreateServiceMemberHandler{handlers.NewHandlerContext(suite.DB(), suite.TestLogger()), sessionManager} response := handler.Handle(params) suite.Assertions.IsType(&handlers.CookieUpdateResponder{}, response) @@ -151,7 +153,9 @@ func (suite *HandlerSuite) TestSubmitServiceMemberHandlerAllValues() { HTTPRequest: req, } - handler := CreateServiceMemberHandler{handlers.NewHandlerContext(suite.DB(), suite.TestLogger())} + var sessionManager *scs.SessionManager + sessionManager = scs.New() + handler := CreateServiceMemberHandler{handlers.NewHandlerContext(suite.DB(), suite.TestLogger()), sessionManager} response := handler.Handle(params) suite.Assertions.IsType(&handlers.CookieUpdateResponder{}, response) @@ -190,7 +194,9 @@ func (suite *HandlerSuite) TestSubmitServiceMemberSSN() { HTTPRequest: req, } - handler := CreateServiceMemberHandler{handlers.NewHandlerContext(suite.DB(), suite.TestLogger())} + var sessionManager *scs.SessionManager + sessionManager = scs.New() + handler := CreateServiceMemberHandler{handlers.NewHandlerContext(suite.DB(), suite.TestLogger()), sessionManager} response := handler.Handle(params) suite.Assertions.IsType(&handlers.CookieUpdateResponder{}, response) diff --git a/pkg/handlers/responders.go b/pkg/handlers/responders.go index 4ba7220b5387..2bb870b5ef31 100644 --- a/pkg/handlers/responders.go +++ b/pkg/handlers/responders.go @@ -1,8 +1,10 @@ package handlers import ( + "context" "net/http" + "github.com/alexedwards/scs/v2" "github.com/go-openapi/runtime" "github.com/go-openapi/runtime/middleware" @@ -12,28 +14,26 @@ import ( // CookieUpdateResponder wraps a swagger middleware.Responder in code which sets the session_cookie // See: https://github.com/go-swagger/go-swagger/issues/748 type CookieUpdateResponder struct { - session *auth.Session - cookieSecret string - noSessionTimeout bool - logger Logger - Responder middleware.Responder - useSecureCookie bool + session *auth.Session + logger Logger + Responder middleware.Responder + sessionManager *scs.SessionManager + ctx context.Context } // NewCookieUpdateResponder constructs a wrapper for the responder which will update cookies -func NewCookieUpdateResponder(request *http.Request, secret string, noSessionTimeout bool, logger Logger, responder middleware.Responder, useSecureCookie bool) middleware.Responder { +func NewCookieUpdateResponder(request *http.Request, logger Logger, responder middleware.Responder, sessionManager *scs.SessionManager, session *auth.Session) middleware.Responder { return &CookieUpdateResponder{ - session: auth.SessionFromRequestContext(request), - cookieSecret: secret, - noSessionTimeout: noSessionTimeout, - logger: logger, - Responder: responder, - useSecureCookie: useSecureCookie, + session: session, + logger: logger, + Responder: responder, + sessionManager: sessionManager, + ctx: request.Context(), } } // WriteResponse updates the session cookie before writing out the details of the response func (cur *CookieUpdateResponder) WriteResponse(rw http.ResponseWriter, p runtime.Producer) { - auth.WriteSessionCookie(rw, cur.session, cur.cookieSecret, cur.noSessionTimeout, cur.logger, cur.useSecureCookie) + cur.sessionManager.Put(cur.ctx, "session", cur.session) cur.Responder.WriteResponse(rw, p) } diff --git a/pkg/middleware/request_logger.go b/pkg/middleware/request_logger.go index fd7a85649954..8c0fb39950a0 100644 --- a/pkg/middleware/request_logger.go +++ b/pkg/middleware/request_logger.go @@ -63,6 +63,12 @@ func RequestLogger(serverLogger Logger) func(inner http.Handler) http.Handler { if session := auth.SessionFromContext(ctx); session != nil { if session.UserID != uuid.Nil { fields = append(fields, zap.String("user-id", session.UserID.String())) + var sessionID string + cookie, err := r.Cookie(auth.SessionCookieName(session)) + if err == nil { + sessionID = cookie.Value + } + fields = append(fields, zap.String("session-id", sessionID)) } if session.IsServiceMember() { fields = append(fields, zap.String("service-member-id", session.ServiceMemberID.String()))