diff --git a/postgrest.cabal b/postgrest.cabal index f06fe783b5..72de8dd18f 100644 --- a/postgrest.cabal +++ b/postgrest.cabal @@ -142,7 +142,6 @@ library , timeit >= 2.0 && < 2.1 , unordered-containers >= 0.2.8 && < 0.3 , unix-compat >= 0.5.4 && < 0.8 - , vault >= 0.3.1.5 && < 0.4 , vector >= 0.11 && < 0.14 , wai >= 3.2.1 && < 3.3 , wai-cors >= 0.2.5 && < 0.3 diff --git a/src/PostgREST/App.hs b/src/PostgREST/App.hs index bc696929d3..2e887a4ffd 100644 --- a/src/PostgREST/App.hs +++ b/src/PostgREST/App.hs @@ -99,13 +99,8 @@ postgrest :: LogLevel -> AppState.AppState -> IO () -> Wai.Application postgrest logLevel appState connWorker = traceHeaderMiddleware appState . Cors.middleware appState . - Auth.middleware appState . - Logger.middleware logLevel Auth.getRole $ - -- fromJust can be used, because the auth middleware will **always** add - -- some AuthResult to the vault. - \req respond -> case fromJust $ Auth.getResult req of - Left err -> respond $ Error.errorResponseFor err - Right authResult -> do + Logger.middleware logLevel $ + \req respond -> do appConf <- AppState.getConfig appState -- the config must be read again because it can reload maybeSchemaCache <- AppState.getSchemaCache appState pgVer <- AppState.getPgVersion appState @@ -113,7 +108,7 @@ postgrest logLevel appState connWorker = let eitherResponse :: IO (Either Error Wai.Response) eitherResponse = - runExceptT $ postgrestResponse appState appConf maybeSchemaCache pgVer authResult req + runExceptT $ postgrestResponse appState appConf maybeSchemaCache pgVer req response <- either Error.errorResponseFor identity <$> eitherResponse -- Launch the connWorker when the connection is down. The postgrest @@ -130,10 +125,9 @@ postgrestResponse -> AppConfig -> Maybe SchemaCache -> PgVersion - -> AuthResult -> Wai.Request -> Handler IO Wai.Response -postgrestResponse appState conf@AppConfig{..} maybeSchemaCache pgVer authResult@AuthResult{..} req = do +postgrestResponse appState conf@AppConfig{..} maybeSchemaCache pgVer req = do sCache <- case maybeSchemaCache of Just sCache -> @@ -143,13 +137,20 @@ postgrestResponse appState conf@AppConfig{..} maybeSchemaCache pgVer authResult@ body <- lift $ Wai.strictRequestBody req - let jwtTime = if configServerTimingEnabled then Auth.getJwtDur req else Nothing - timezones = dbTimezones sCache - prefs = ApiRequest.userPreferences conf req timezones + -- API-REQUEST/PARSE STAGE + let prefs = ApiRequest.userPreferences conf req (dbTimezones sCache) (parseTime, apiReq@ApiRequest{..}) <- withTiming $ liftEither . mapLeft Error.ApiRequestError $ ApiRequest.userApiRequest conf prefs req body - (planTime, plan) <- withTiming $ liftEither $ Plan.actionPlan iAction conf apiReq sCache + -- JWT/AUTH STAGE + (jwtTime, authResult@AuthResult{..}) <- withTiming $ do + eitherAuthResult <- liftIO $ Auth.getAuthResult appState apiReq + liftEither eitherAuthResult + + -- PLAN STAGE + (planTime, plan) <- withTiming $ liftEither $ Plan.actionPlan iAction conf apiReq sCache + + -- QUERY/TRANSACTION STAGE let query = Query.query conf authResult apiReq plan sCache pgVer logSQL = lift . AppState.getObserver appState . DBQuery (Query.getSQLQuery query) @@ -162,6 +163,7 @@ postgrestResponse appState conf@AppConfig{..} maybeSchemaCache pgVer authResult@ when (configLogQuery /= LogQueryDisabled) $ whenLeft eitherResp $ logSQL . Error.status liftEither eitherResp >>= liftEither + -- RESPONSE STAGE (respTime, resp) <- withTiming $ do let response = Response.actionResponse queryResult apiReq (T.decodeUtf8 prettyVersion, docsVersion) conf sCache iSchema iNegotiatedByProfile when (configLogQuery /= LogQueryDisabled) $ logSQL $ either Error.status Response.pgrstStatus response diff --git a/src/PostgREST/Auth.hs b/src/PostgREST/Auth.hs index cbcef6b5ec..b64f972cc9 100644 --- a/src/PostgREST/Auth.hs +++ b/src/PostgREST/Auth.hs @@ -12,11 +12,8 @@ very simple authentication system inside the PostgreSQL database. -} {-# LANGUAGE RecordWildCards #-} module PostgREST.Auth - ( getResult - , getJwtDur - , getRole - , middleware - ) where + ( getAuthResult ) + where import qualified Data.Aeson as JSON import qualified Data.Aeson.Key as K @@ -25,14 +22,13 @@ import qualified Data.Aeson.Types as JSON import qualified Data.ByteString as BS import qualified Data.ByteString.Internal as BS import qualified Data.ByteString.Lazy.Char8 as LBS +import qualified Data.CaseInsensitive as CI import qualified Data.Scientific as Sci import qualified Data.Text as T -import qualified Data.Vault.Lazy as Vault import qualified Data.Vector as V import qualified Jose.Jwk as JWT import qualified Jose.Jwt as JWT import qualified Network.HTTP.Types.Header as HTTP -import qualified Network.Wai as Wai import qualified Network.Wai.Middleware.HttpAuth as Wai import Control.Monad.Except (liftEither) @@ -40,9 +36,8 @@ import Data.Either.Combinators (mapLeft) import Data.List (lookup) import Data.Time.Clock (UTCTime, nominalDiffTimeToSeconds) import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds) -import System.IO.Unsafe (unsafePerformIO) -import System.TimeIt (timeItT) +import PostgREST.ApiRequest (ApiRequest (..)) import PostgREST.AppState (AppState, getConfig, getJwtCacheState, getTime) import PostgREST.Auth.JwtCache (lookupJwtCache) @@ -131,11 +126,12 @@ parseClaims AppConfig{..} jclaims@(JSON.Object mclaims) = do walkJSPath x [] = x walkJSPath (Just (JSON.Object o)) (JSPKey key:rest) = walkJSPath (KM.lookup (K.fromText key) o) rest walkJSPath (Just (JSON.Array ar)) (JSPIdx idx:rest) = walkJSPath (ar V.!? idx) rest - walkJSPath (Just (JSON.Array ar)) [JSPFilter (EqualsCond txt)] = findFirstMatch (==) txt ar - walkJSPath (Just (JSON.Array ar)) [JSPFilter (NotEqualsCond txt)] = findFirstMatch (/=) txt ar - walkJSPath (Just (JSON.Array ar)) [JSPFilter (StartsWithCond txt)] = findFirstMatch T.isPrefixOf txt ar - walkJSPath (Just (JSON.Array ar)) [JSPFilter (EndsWithCond txt)] = findFirstMatch T.isSuffixOf txt ar - walkJSPath (Just (JSON.Array ar)) [JSPFilter (ContainsCond txt)] = findFirstMatch T.isInfixOf txt ar + walkJSPath (Just (JSON.Array ar)) [JSPFilter filterCond] = case filterCond of + EqualsCond txt -> findFirstMatch (==) txt ar + NotEqualsCond txt -> findFirstMatch (/=) txt ar + StartsWithCond txt -> findFirstMatch T.isPrefixOf txt ar + EndsWithCond txt -> findFirstMatch T.isSuffixOf txt ar + ContainsCond txt -> findFirstMatch T.isInfixOf txt ar walkJSPath _ _ = Nothing findFirstMatch matchWith pattern = foldr checkMatch Nothing @@ -151,55 +147,21 @@ parseClaims AppConfig{..} jclaims@(JSON.Object mclaims) = do -- impossible case - just added to please -Wincomplete-patterns parseClaims _ _ = return AuthResult { authClaims = KM.empty, authRole = mempty } --- | Validate authorization header. --- Parse and store JWT claims for future use in the request. -middleware :: AppState -> Wai.Middleware -middleware appState app req respond = do +-- | Perform authentication and authorization +-- Parse JWT and return AuthResult +getAuthResult :: AppState -> ApiRequest -> IO (Either Error AuthResult) +getAuthResult appState ApiRequest{..} = do conf <- getConfig appState time <- getTime appState - let token = Wai.extractBearerAuth =<< lookup HTTP.hAuthorization (Wai.requestHeaders req) + let ciHdrs = map (first CI.mk) iHeaders + token = Wai.extractBearerAuth =<< lookup HTTP.hAuthorization ciHdrs parseJwt = runExceptT $ parseToken conf token time >>= parseClaims conf jwtCacheState = getJwtCacheState appState --- If ServerTimingEnabled -> calculate JWT validation time --- If JwtCacheMaxLifetime -> cache JWT validation result - req' <- case (configServerTimingEnabled conf, configJwtCacheMaxLifetime conf) of - (True, 0) -> do - (dur, authResult) <- timeItT parseJwt - return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur } - - (True, maxLifetime) -> do - (dur, authResult) <- timeItT $ case token of - Just tkn -> lookupJwtCache jwtCacheState tkn maxLifetime parseJwt time - Nothing -> parseJwt - return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur } - - (False, 0) -> do - authResult <- parseJwt - return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult } - - (False, maxLifetime) -> do - authResult <- case token of - Just tkn -> lookupJwtCache jwtCacheState tkn maxLifetime parseJwt time - Nothing -> parseJwt - return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult } - - app req' respond - -authResultKey :: Vault.Key (Either Error AuthResult) -authResultKey = unsafePerformIO Vault.newKey -{-# NOINLINE authResultKey #-} - -getResult :: Wai.Request -> Maybe (Either Error AuthResult) -getResult = Vault.lookup authResultKey . Wai.vault - -jwtDurKey :: Vault.Key Double -jwtDurKey = unsafePerformIO Vault.newKey -{-# NOINLINE jwtDurKey #-} - -getJwtDur :: Wai.Request -> Maybe Double -getJwtDur = Vault.lookup jwtDurKey . Wai.vault - -getRole :: Wai.Request -> Maybe BS.ByteString -getRole req = authRole <$> (rightToMaybe =<< getResult req) + case configJwtCacheMaxLifetime conf of + 0 -> parseJwt -- If 0 then cache is diabled; no lookup + maxLifetime -> case token of + -- Lookup only if token found in header + Just tkn -> lookupJwtCache jwtCacheState tkn maxLifetime parseJwt time + Nothing -> parseJwt diff --git a/src/PostgREST/Logger.hs b/src/PostgREST/Logger.hs index dac4092234..b9466a8919 100644 --- a/src/PostgREST/Logger.hs +++ b/src/PostgREST/Logger.hs @@ -10,10 +10,9 @@ module PostgREST.Logger , LoggerState ) where -import Control.AutoUpdate (defaultUpdateSettings, - mkAutoUpdate, updateAction) -import Control.Debounce -import qualified Data.ByteString.Char8 as BS +import Control.AutoUpdate (defaultUpdateSettings, mkAutoUpdate, + updateAction) +import Control.Debounce import Data.Time (ZonedTime, defaultTimeLocale, formatTime, getZonedTime) @@ -56,15 +55,14 @@ logWithDebounce loggerState action = do newDebouncer -- TODO stop using this middleware to reuse the same "observer" pattern for all our logs -middleware :: LogLevel -> (Wai.Request -> Maybe BS.ByteString) -> Wai.Middleware -middleware logLevel getAuthRole = +middleware :: LogLevel -> Wai.Middleware +middleware logLevel = unsafePerformIO $ Wai.mkRequestLogger Wai.defaultRequestLoggerSettings { Wai.outputFormat = Wai.ApacheWithSettings $ Wai.defaultApacheSettings & - Wai.setApacheRequestFilter (\_ res -> shouldLogResponse logLevel $ Wai.responseStatus res) & - Wai.setApacheUserGetter getAuthRole + Wai.setApacheRequestFilter (\_ res -> shouldLogResponse logLevel $ Wai.responseStatus res) , Wai.autoFlush = True , Wai.destination = Wai.Handle stdout } diff --git a/test/io/test_io.py b/test/io/test_io.py index 72c2ff8927..ac05df5d32 100644 --- a/test/io/test_io.py +++ b/test/io/test_io.py @@ -959,6 +959,8 @@ def test_log_level(level, defaultenv): assert response.status_code == 200 output = sorted(postgrest.read_stdout(nlines=7)) + for line in output: + print(line) if level == "crit": assert len(output) == 0 @@ -974,35 +976,35 @@ def test_log_level(level, defaultenv): output[0], ) assert re.match( - r'- - postgrest_test_anonymous \[.+\] "GET /unknown HTTP/1.1" 404 \d+ "" "python-requests/.+"', + r'- - - \[.+\] "GET /unknown HTTP/1.1" 404 \d+ "" "python-requests/.+"', output[1], ) assert len(output) == 2 elif level == "info": assert re.match( - r'- - - \[.+\] "GET / HTTP/1.1" 500 \d+ "" "python-requests/.+"', + r'- - - \[.+\] "GET / HTTP/1.1" 200 \d+ "" "python-requests/.+"', output[0], ) assert re.match( - r'- - postgrest_test_anonymous \[.+\] "GET / HTTP/1.1" 200 \d+ "" "python-requests/.+"', + r'- - - \[.+\] "GET / HTTP/1.1" 500 \d+ "" "python-requests/.+"', output[1], ) assert re.match( - r'- - postgrest_test_anonymous \[.+\] "GET /unknown HTTP/1.1" 404 \d+ "" "python-requests/.+"', + r'- - - \[.+\] "GET /unknown HTTP/1.1" 404 \d+ "" "python-requests/.+"', output[2], ) assert len(output) == 3 elif level == "debug": assert re.match( - r'- - - \[.+\] "GET / HTTP/1.1" 500 \d+ "" "python-requests/.+"', + r'- - - \[.+\] "GET / HTTP/1.1" 200 \d+ "" "python-requests/.+"', output[0], ) assert re.match( - r'- - postgrest_test_anonymous \[.+\] "GET / HTTP/1.1" 200 \d+ "" "python-requests/.+"', + r'- - - \[.+\] "GET / HTTP/1.1" 500 \d+ "" "python-requests/.+"', output[1], ) assert re.match( - r'- - postgrest_test_anonymous \[.+\] "GET /unknown HTTP/1.1" 404 \d+ "" "python-requests/.+"', + r'- - - \[.+\] "GET /unknown HTTP/1.1" 404 \d+ "" "python-requests/.+"', output[2], )