Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 36 additions & 34 deletions src/Servant/Server/Experimental/Auth/Cookie.hs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ module Servant.Server.Experimental.Auth.Cookie
, mkPadding
, mkMAC
, applyCipherAlgorithm

-- reexports from Web.Cookie
, sameSiteLax
, sameSiteStrict
) where

import Blaze.ByteString.Builder (toByteString)
Expand Down Expand Up @@ -146,7 +150,6 @@ import qualified Crypto.MAC.HMAC as H
import qualified Data.ByteArray as BA
import qualified Data.ByteString as BS
import qualified Data.ByteString.Base64 as Base64
import qualified Data.ByteString.Char8 as BSC8
import qualified Data.Serialize as Serialize (encode, decode)
import qualified Network.HTTP.Types as N(Header)

Expand Down Expand Up @@ -198,11 +201,6 @@ instance Default ExpirationType where

instance Serialize ExpirationType

-- | Format used in 'Expires' cookie field.
expirationFormat :: String
expirationFormat = "%a, %d %b %Y %H:%M:%S GMT"


-- | Wrapper for session value that goes into cookies' payload.
data PayloadWrapper a = PayloadWrapper {
pwSession :: a
Expand Down Expand Up @@ -476,8 +474,12 @@ data AuthCookieSettings where
AuthCookieSettings :: (HashAlgorithm h, BlockCipher c) =>
{ acsSessionField :: ByteString
-- ^ Name of a cookie which stores session object
, acsCookieFlags :: [ByteString]
-- ^ Session cookie's flags
, acsHttpOnly :: Bool
-- ^ whether the cookie should be marked httponly, i.e. not accessible by JavaScript
, acsSecure :: Bool
-- ^ whether the cookie should be marked as secure, i.e. only transmitted over secure connections
, acsSameSite :: Maybe SameSiteOption
-- ^ whether the cookie should only be transmitted to the originating site
, acsMaxAge :: NominalDiffTime
-- ^ For how long the cookie will be valid (corresponds to “Max-Age”
-- or "Expires" attribute).
Expand All @@ -496,7 +498,9 @@ data AuthCookieSettings where
instance Default AuthCookieSettings where
def = AuthCookieSettings
{ acsSessionField = "Session"
, acsCookieFlags = ["HttpOnly", "Secure"]
, acsHttpOnly = True
, acsSecure = True
, acsSameSite = Nothing
, acsMaxAge = fromIntegral (12 * 3600 :: Integer) -- 12 hours
, acsPath = "/"
, acsHashAlgorithm = Proxy :: Proxy SHA256
Expand Down Expand Up @@ -697,41 +701,40 @@ parseSessionResponse acs hdrs = parseSession acs hSetCookie hdrs
renderSession'
:: AuthCookieSettings
-> (Tagged SerializedEncryptedCookie ByteString)
-> Maybe (ByteString, ByteString)
-> (Maybe UTCTime, Maybe DiffTime)
-> ByteString
renderSession' AuthCookieSettings{..} (Tagged sessionBinary) expiration
= toByteString . renderCookies
$ (acsSessionField, sessionBinary)
: ("Path", acsPath)
: ((maybe id (:) expiration)
$ ((,"") <$> acsCookieFlags))
renderSession' AuthCookieSettings{..} (Tagged sessionBinary) (cookieExpires, cookieMaxAge)
= (toByteString . renderSetCookie) $ defaultSetCookie
{ setCookieName = acsSessionField
, setCookieValue = sessionBinary
, setCookiePath = Just acsPath
, setCookieExpires = cookieExpires
, setCookieMaxAge = cookieMaxAge
, setCookieHttpOnly = acsHttpOnly
, setCookieSecure = acsSecure
, setCookieSameSite = acsSameSite
}

-- | Render session cookie to 'ByteString'.
renderSession :: AddSession () ByteString
renderSession acs rs sk pwSettings pwSession _ = liftM2 (renderSession' acs)
(encryptSession acs rs sk pwSettings pwSession)
(renderExpiration (acsMaxAge acs) (ssExpirationType pwSettings))
(renderExpiration' (acsMaxAge acs) (ssExpirationType pwSettings))

-- | Render expired session to 'ByteString' (the date is set at 0 and the content is wiped).
renderExpiredSession :: AuthCookieSettings -> ByteString
renderExpiredSession acs = renderSession' acs (Tagged "") (Just ("Expires", longTimeAgo)) where
longTimeAgo = BSC8.pack $ formatTime
defaultTimeLocale
expirationFormat
timeOrigin
renderExpiredSession acs = renderSession' acs (Tagged "") (Just timeOrigin, Nothing) where
timeOrigin = UTCTime (toEnum 0) 0

-- | Render expiration value depending on it's type.
renderExpiration :: (MonadIO m) => NominalDiffTime -> ExpirationType -> m (Maybe (ByteString, ByteString))

renderExpiration maxAge Expires
= liftM (addUTCTime maxAge) (liftIO getCurrentTime)
>>= \t -> return . Just $ ("Expires", BSC8.pack $ formatTime defaultTimeLocale expirationFormat t)

renderExpiration maxAge MaxAge = return . Just $ ("Max-Age", (BSC8.pack . show . n) maxAge)
where n = floor :: NominalDiffTime -> Int

renderExpiration _ Session = return Nothing
renderExpiration' :: (MonadIO m) => NominalDiffTime -> ExpirationType -> m (Maybe UTCTime, Maybe DiffTime)
renderExpiration' maxAge expirationType =
case expirationType of
Session -> return (Nothing, Nothing)
MaxAge ->
return (Nothing, Just (fromRational . toRational $ maxAge))
Expires -> do
expirationTime <- liftIO $ (addUTCTime maxAge) <$> getCurrentTime
return (Just expirationTime, Nothing)


#if MIN_VERSION_servant(0,9,1)
Expand Down Expand Up @@ -910,4 +913,3 @@ unProxy Proxy = undefined
-- | Generates random sequence of bytes from new DRG
generateRandomBytes :: Int -> IO ByteString
generateRandomBytes size = (fst . randomBytesGenerate size <$> drgNew)

2 changes: 1 addition & 1 deletion stack.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
resolver: lts-8.9
resolver: lts-9.13
packages:
- '.'
extra-deps:
Expand Down