From 380acb3efa5e0162fd87519abae893fc67da3244 Mon Sep 17 00:00:00 2001
From: "Julian K. Arni" <jkarni@gmail.com>
Date: Mon, 12 Jan 2015 15:08:41 +0100
Subject: [PATCH] Add Accept header handling.

---
 default.nix                             |  15 +++
 servant-server.cabal                    |   2 +
 src/Servant/Server/ContentTypes.hs      | 114 +++++++++++++++++++++
 src/Servant/Server/Internal.hs          |  53 ++++++----
 test/Servant/Server/ContentTypesSpec.hs | 129 ++++++++++++++++++++++++
 test/Servant/ServerSpec.hs              |  19 ++--
 test/Servant/Utils/StaticFilesSpec.hs   |   3 +-
 7 files changed, 308 insertions(+), 27 deletions(-)
 create mode 100644 default.nix
 create mode 100644 src/Servant/Server/ContentTypes.hs
 create mode 100644 test/Servant/Server/ContentTypesSpec.hs

diff --git a/default.nix b/default.nix
new file mode 100644
index 0000000..e8a420d
--- /dev/null
+++ b/default.nix
@@ -0,0 +1,15 @@
+{ pkgs ? import <nixpkgs> { config.allowUnfree = true; }
+, src ?  builtins.filterSource (path: type:
+    type != "unknown" &&
+    baseNameOf path != ".git" &&
+    baseNameOf path != "result" &&
+    baseNameOf path != "dist") ./.
+, servant ? import ../servant {}
+}:
+pkgs.haskellPackages.buildLocalCabalWithArgs {
+  name = "servant-server";
+  inherit src;
+  args = {
+      inherit servant;
+  };
+}
diff --git a/servant-server.cabal b/servant-server.cabal
index 7da021e..077f63e 100644
--- a/servant-server.cabal
+++ b/servant-server.cabal
@@ -31,6 +31,7 @@ library
   exposed-modules:
     Servant
     Servant.Server
+    Servant.Server.ContentTypes
     Servant.Server.Internal
     Servant.Utils.StaticFiles
   build-depends:
@@ -41,6 +42,7 @@ library
     , either >= 4.3
     , http-types
     , network-uri >= 2.6
+    , http-media == 0.4.*
     , safe
     , servant >= 0.2.2
     , split
diff --git a/src/Servant/Server/ContentTypes.hs b/src/Servant/Server/ContentTypes.hs
new file mode 100644
index 0000000..e714421
--- /dev/null
+++ b/src/Servant/Server/ContentTypes.hs
@@ -0,0 +1,114 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+module Servant.Server.ContentTypes where
+
+import Data.Aeson (ToJSON(..), encode)
+import Data.ByteString.Lazy (ByteString)
+import qualified Data.ByteString as BS
+import Data.Proxy (Proxy(..))
+import Data.String.Conversions (cs)
+import qualified Network.HTTP.Media as M
+
+
+import Servant.API (XML, HTML, JSON, JavaScript, CSS, PlainText)
+
+-- | Instances of 'Accept' represent mimetypes. They are used for matching
+-- against the @Accept@ HTTP header of the request, and for setting the
+-- @Content-Type@ header of the response
+--
+-- Example:
+--
+--   instance Accept HTML where
+--      contentType _ = "text" // "html"
+--
+class Accept ctype where
+    contentType   :: Proxy ctype -> M.MediaType
+
+instance Accept HTML where
+    contentType _ = "text" M.// "html"
+
+instance Accept JSON where
+    contentType _ = "application" M.// "json"
+
+instance Accept XML where
+    contentType _ = "application" M.// "xml"
+
+instance Accept JavaScript where
+    contentType _ = "application" M.// "javascript"
+
+instance Accept CSS where
+    contentType _ = "text" M.// "css"
+
+instance Accept PlainText where
+    contentType _ = "text" M.// "plain"
+
+newtype AcceptHeader = AcceptHeader BS.ByteString
+    deriving (Eq, Show)
+
+-- | Instantiate this class to register a way of serializing a type based
+-- on the @Accept@ header.
+class Accept ctype => MimeRender ctype a where
+    toByteString  :: Proxy ctype -> a -> ByteString
+
+class AllCTRender list a where
+    -- If the Accept header can be matched, returns (Just) a tuple of the
+    -- Content-Type and response (serialization of @a@ into the appropriate
+    -- mimetype).
+    handleAcceptH :: Proxy list -> AcceptHeader -> a -> Maybe (ByteString, ByteString)
+
+instance ( AllMimeRender ctyps a, IsEmpty ctyps ~ 'False
+         ) => AllCTRender ctyps a where
+    handleAcceptH _ (AcceptHeader accept) val = M.mapAcceptMedia lkup accept
+      where pctyps = Proxy :: Proxy ctyps
+            amrs = amr pctyps val
+            lkup = zip (map fst amrs) $ map (\(a,b) -> (cs $ show a, b)) amrs
+
+
+--------------------------------------------------------------------------
+-- Check that all elements of list are instances of MimeRender
+--------------------------------------------------------------------------
+class AllMimeRender ls a where
+    amr :: Proxy ls -> a -> [(M.MediaType, ByteString)] -- list of content-types/response pairs
+
+instance ( MimeRender ctyp a ) => AllMimeRender '[ctyp] a where
+    amr _ a = [(contentType pctyp, toByteString pctyp a)]
+        where pctyp = Proxy :: Proxy ctyp
+
+instance ( MimeRender ctyp a
+         , MimeRender ctyp' a
+         , AllMimeRender ctyps a
+         ) => AllMimeRender (ctyp ': ctyp' ': ctyps) a where
+    amr _ a = (contentType pctyp, toByteString pctyp a)
+             :(contentType pctyp', toByteString pctyp' a)
+             :(amr pctyps a)
+        where pctyp = Proxy :: Proxy ctyp
+              pctyps = Proxy :: Proxy ctyps
+              pctyp' = Proxy :: Proxy ctyp'
+
+
+instance AllMimeRender '[] a where
+    amr _ _ = []
+
+type family IsEmpty (ls::[*]) where
+    IsEmpty '[] = 'True
+    IsEmpty x   = 'False
+
+--------------------------------------------------------------------------
+-- MimeRender Instances
+--------------------------------------------------------------------------
+
+instance ToJSON a => MimeRender JSON a where
+    toByteString _ = encode
+
+instance Show a => MimeRender PlainText a where
+    toByteString _ = encode . show
+
+instance MimeRender PlainText String where
+    toByteString _ = encode
diff --git a/src/Servant/Server/Internal.hs b/src/Servant/Server/Internal.hs
index 4bd0a08..033129b 100644
--- a/src/Servant/Server/Internal.hs
+++ b/src/Servant/Server/Internal.hs
@@ -24,10 +24,14 @@ import qualified Data.Text as T
 import Data.Typeable
 import GHC.TypeLits (KnownSymbol, symbolVal)
 import Network.HTTP.Types hiding (Header)
-import Network.Wai (Response, Request, ResponseReceived, Application, pathInfo, requestBody,
-                    strictRequestBody, lazyRequestBody, requestHeaders, requestMethod,
+import Network.Wai ( Response, Request, ResponseReceived, Application
+                   , pathInfo, requestBody, strictRequestBody
+                   , lazyRequestBody, requestHeaders, requestMethod,
                     rawQueryString, responseLBS)
-import Servant.API (QueryParams, QueryParam, QueryFlag, MatrixParams, MatrixParam, MatrixFlag, ReqBody, Header, Capture, Get, Delete, Put, Post, Patch, Raw, (:>), (:<|>)(..))
+import Servant.API ( QueryParams, QueryParam, QueryFlag, ReqBody, Header
+                   , MatrixParams, MatrixParam, MatrixFlag,
+                   , Capture, Get, Delete, Put, Post, Patch, Raw, (:>), (:<|>)(..))
+import Servant.Server.ContentTypes (AllCTRender(..), AcceptHeader(..))
 import Servant.Common.Text (FromText, fromText)
 
 data ReqBodyState = Uncalled
@@ -225,7 +229,7 @@ instance (KnownSymbol capture, FromText a, HasServer sublayout)
     _ -> respond $ failWith NotFound
 
     where captureProxy = Proxy :: Proxy (Capture capture a)
-           
+
 
 -- | If you have a 'Delete' endpoint in your API,
 -- the handler for this endpoint is meant to delete
@@ -264,14 +268,19 @@ instance HasServer Delete where
 -- If successfully returning a value, we just require that its type has
 -- a 'ToJSON' instance and servant takes care of encoding it for you,
 -- yielding status code 200 along the way.
-instance ToJSON result => HasServer (Get result) where
-  type Server (Get result) = EitherT (Int, String) IO result
+instance ( AllCTRender ctypes a, ToJSON a
+         ) => HasServer (Get ctypes a) where
+  type Server (Get ctypes a) = EitherT (Int, String) IO a
   route Proxy action request respond
     | pathIsEmpty request && requestMethod request == methodGet = do
         e <- runEitherT action
         respond . succeedWith $ case e of
-          Right output ->
-            responseLBS ok200 [("Content-Type", "application/json")] (encode output)
+          Right output -> do
+            let accH = fromMaybe "*/*" $ lookup hAccept $ requestHeaders request
+            case handleAcceptH (Proxy :: Proxy ctypes) (AcceptHeader accH) output of
+              Nothing -> responseLBS (mkStatus 406 "") [] ""
+              Just (contentT, body) -> responseLBS ok200 [ ("Content-Type"
+                                                         , cs contentT)] body
           Left (status, message) ->
             responseLBS (mkStatus status (cs message)) [] (cs message)
     | pathIsEmpty request && requestMethod request /= methodGet =
@@ -321,15 +330,20 @@ instance (KnownSymbol sym, FromText a, HasServer sublayout)
 -- If successfully returning a value, we just require that its type has
 -- a 'ToJSON' instance and servant takes care of encoding it for you,
 -- yielding status code 201 along the way.
-instance ToJSON a => HasServer (Post a) where
-  type Server (Post a) = EitherT (Int, String) IO a
+instance ( AllCTRender ctypes a, ToJSON a
+         )=> HasServer (Post ctypes a) where
+  type Server (Post ctypes a) = EitherT (Int, String) IO a
 
   route Proxy action request respond
     | pathIsEmpty request && requestMethod request == methodPost = do
         e <- runEitherT action
         respond . succeedWith $ case e of
-          Right out ->
-            responseLBS status201 [("Content-Type", "application/json")] (encode out)
+          Right output -> do
+            let accH = fromMaybe "*/*" $ lookup hAccept $ requestHeaders request
+            case handleAcceptH (Proxy :: Proxy ctypes) (AcceptHeader accH) output of
+              Nothing -> responseLBS (mkStatus 406 "") [] ""
+              Just (contentT, body) -> responseLBS status201 [ ("Content-Type"
+                                                             , cs contentT)] body
           Left (status, message) ->
             responseLBS (mkStatus status (cs message)) [] (cs message)
     | pathIsEmpty request && requestMethod request /= methodPost =
@@ -347,15 +361,20 @@ instance ToJSON a => HasServer (Post a) where
 -- If successfully returning a value, we just require that its type has
 -- a 'ToJSON' instance and servant takes care of encoding it for you,
 -- yielding status code 200 along the way.
-instance ToJSON a => HasServer (Put a) where
-  type Server (Put a) = EitherT (Int, String) IO a
+instance ( AllCTRender ctypes a, ToJSON a
+         ) => HasServer (Put ctypes a) where
+  type Server (Put ctypes a) = EitherT (Int, String) IO a
 
   route Proxy action request respond
     | pathIsEmpty request && requestMethod request == methodPut = do
         e <- runEitherT action
         respond . succeedWith $ case e of
-          Right out ->
-            responseLBS ok200 [("Content-Type", "application/json")] (encode out)
+          Right output -> do
+            let accH = fromMaybe "*/*" $ lookup hAccept $ requestHeaders request
+            case handleAcceptH (Proxy :: Proxy ctypes) (AcceptHeader accH) output of
+              Nothing -> responseLBS (mkStatus 406 "") [] ""
+              Just (contentT, body) -> responseLBS status200 [ ("Content-Type"
+                                                             , cs contentT)] body
           Left (status, message) ->
             responseLBS (mkStatus status (cs message)) [] (cs message)
     | pathIsEmpty request && requestMethod request /= methodPut =
@@ -382,7 +401,7 @@ instance (Typeable a, ToJSON a) => HasServer (Patch a) where
         e <- runEitherT action
         respond . succeedWith $ case e of
           Right out -> case cast out of
-              Nothing -> responseLBS status200 [("Content-Type", "application/json")] (encode out) 
+              Nothing -> responseLBS status200 [("Content-Type", "application/json")] (encode out)
               Just () -> responseLBS status204 [] ""
           Left (status, message) ->
             responseLBS (mkStatus status (cs message)) [] (cs message)
diff --git a/test/Servant/Server/ContentTypesSpec.hs b/test/Servant/Server/ContentTypesSpec.hs
new file mode 100644
index 0000000..8d725f1
--- /dev/null
+++ b/test/Servant/Server/ContentTypesSpec.hs
@@ -0,0 +1,129 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE TypeOperators #-}
+{-# OPTIONS_GHC -fno-warn-orphans #-}
+module Servant.Server.ContentTypesSpec where
+
+import Control.Applicative
+import Data.Aeson (encode)
+import Data.ByteString.Char8
+import Data.Function (on)
+import Data.Maybe (isJust, fromJust)
+import Data.List (maximumBy)
+import Data.Proxy (Proxy(..))
+import Data.String (IsString(..))
+import Data.String.Conversions (cs)
+import Network.HTTP.Types (hAccept)
+import Network.Wai (pathInfo, requestHeaders)
+import Network.Wai.Test ( runSession, request, defaultRequest
+                        , assertContentType, assertStatus )
+import Test.Hspec
+import Test.QuickCheck
+
+import Servant.API
+import Servant.Server
+import Servant.Server.ContentTypes
+
+
+spec :: Spec
+spec = describe "Servant.Server.ContentTypes" $ do
+    handleAcceptHSpec
+    contentTypeSpec
+
+handleAcceptHSpec :: Spec
+handleAcceptHSpec = describe "handleAcceptH" $ do
+
+    it "should return Just if the 'Accept' header matches" $ do
+        handleAcceptH (Proxy :: Proxy '[JSON]) "*/*" (3 :: Int)
+            `shouldSatisfy` isJust
+        handleAcceptH (Proxy :: Proxy '[XML, JSON]) "application/json" (3 :: Int)
+            `shouldSatisfy` isJust
+        handleAcceptH (Proxy :: Proxy '[XML, JSON, HTML]) "text/html" (3 :: Int)
+            `shouldSatisfy` isJust
+
+    it "should return the Content-Type as the first element of the tuple" $ do
+        handleAcceptH (Proxy :: Proxy '[JSON]) "*/*" (3 :: Int)
+            `shouldSatisfy` ((== "application/json") . fst . fromJust)
+        handleAcceptH (Proxy :: Proxy '[XML, JSON]) "application/json" (3 :: Int)
+            `shouldSatisfy` ((== "application/json") . fst . fromJust)
+        handleAcceptH (Proxy :: Proxy '[XML, JSON, HTML]) "text/html" (3 :: Int)
+            `shouldSatisfy` ((== "text/html") . fst . fromJust)
+
+    it "should return the appropriately serialized representation" $ do
+        property $ \x -> handleAcceptH (Proxy :: Proxy '[JSON]) "*/*" (x :: Int)
+            == Just ("application/json", encode x)
+
+    it "respects the Accept spec ordering" $
+        property $ \a b c i -> fst (fromJust $ val a b c i) == (fst $ highest a b c)
+          where
+            highest a b c = maximumBy (compare `on` snd) [ ("text/html", a)
+                                                      , ("application/json", b)
+                                                      , ("application/xml", c)
+                                                      ]
+            acceptH a b c = addToAccept (Proxy :: Proxy HTML) a $
+                            addToAccept (Proxy :: Proxy JSON) b $
+                            addToAccept (Proxy :: Proxy XML ) c ""
+            val a b c i = handleAcceptH (Proxy :: Proxy '[HTML, JSON, XML])
+                                        (acceptH a b c) (i :: Int)
+
+type ContentTypeApi = "foo" :> Get '[JSON] Int
+                 :<|> "bar" :> Get '[JSON, PlainText] Int
+
+contentTypeApi :: Proxy ContentTypeApi
+contentTypeApi = Proxy
+
+contentTypeServer :: Server ContentTypeApi
+contentTypeServer = return 5 :<|> return 3
+
+contentTypeSpec :: Spec
+contentTypeSpec = do
+    describe "Accept Headers" $ do
+
+        it "uses the highest quality possible in the header" $
+            flip runSession (serve contentTypeApi contentTypeServer) $ do
+                let acceptH = "text/plain; q=0.9, application/json; q=0.8"
+                response <- Network.Wai.Test.request defaultRequest{
+                    requestHeaders = [(hAccept, acceptH)] ,
+                    pathInfo = ["bar"]
+                }
+                assertContentType "text/plain" response
+
+        it "returns the first content-type if the Accept header is missing" $
+            flip runSession (serve contentTypeApi contentTypeServer) $ do
+                response <- Network.Wai.Test.request defaultRequest{
+                    pathInfo = ["bar"]
+                }
+                assertContentType "application/json" response
+
+        it "returns 406 if it can't serve the requested content-type" $
+            flip runSession (serve contentTypeApi contentTypeServer) $ do
+                let acceptH = "text/css"
+                response <- Network.Wai.Test.request defaultRequest{
+                    requestHeaders = [(hAccept, acceptH)] ,
+                    pathInfo = ["bar"]
+                }
+                assertStatus 406 response
+
+
+instance Show a => MimeRender HTML a where
+    toByteString _ = cs . show
+
+instance Show a => MimeRender XML a where
+    toByteString _ = cs . show
+
+instance IsString AcceptHeader where
+    fromString = AcceptHeader . fromString
+
+addToAccept :: Accept a => Proxy a -> ZeroToOne -> AcceptHeader -> AcceptHeader
+addToAccept p (ZeroToOne f) (AcceptHeader h) = AcceptHeader (cont h)
+    where new = cs (show $ contentType p) `append` "; q=" `append` pack (show f)
+          cont "" = new
+          cont old = old `append` ", " `append` new
+
+newtype ZeroToOne = ZeroToOne Float
+    deriving (Eq, Show, Ord)
+
+instance Arbitrary ZeroToOne where
+    arbitrary = ZeroToOne <$> elements [ x / 10 | x <- [1..10]]
diff --git a/test/Servant/ServerSpec.hs b/test/Servant/ServerSpec.hs
index ee3a8d2..c173c3a 100644
--- a/test/Servant/ServerSpec.hs
+++ b/test/Servant/ServerSpec.hs
@@ -22,6 +22,7 @@ import Network.Wai.Test (runSession, request, defaultRequest, simpleBody)
 import Test.Hspec (Spec, describe, it, shouldBe)
 import Test.Hspec.Wai (liftIO, with, get, post, shouldRespondWith, matchStatus)
 
+import Servant.API (JSON)
 import Servant.API.Capture (Capture)
 import Servant.API.Get (Get)
 import Servant.API.ReqBody (ReqBody)
@@ -79,7 +80,7 @@ spec = do
   errorsSpec
 
 
-type CaptureApi = Capture "legs" Integer :> Get Animal
+type CaptureApi = Capture "legs" Integer :> Get '[JSON] Animal
 captureApi :: Proxy CaptureApi
 captureApi = Proxy
 captureServer :: Integer -> EitherT (Int, String) IO Animal
@@ -105,7 +106,7 @@ captureSpec = do
         get "/captured/foo" `shouldRespondWith` (fromString (show ["foo" :: String]))
 
 
-type GetApi = Get Person
+type GetApi = Get '[JSON] Person
 getApi :: Proxy GetApi
 getApi = Proxy
 
@@ -123,9 +124,9 @@ getSpec = do
         post "/" "" `shouldRespondWith` 405
 
 
-type QueryParamApi = QueryParam "name" String :> Get Person
-                :<|> "a" :> QueryParams "names" String :> Get Person
-                :<|> "b" :> QueryFlag "capitalize" :> Get Person
+type QueryParamApi = QueryParam "name" String :> Get '[JSON] Person
+                :<|> "a" :> QueryParams "names" String :> Get '[JSON] Person
+                :<|> "b" :> QueryFlag "capitalize" :> Get '[JSON] Person
 
 queryParamApi :: Proxy QueryParamApi
 queryParamApi = Proxy
@@ -289,8 +290,8 @@ matrixParamSpec = do
              }
 
 type PostApi =
-       ReqBody Person :> Post Integer
-  :<|> "bla" :> ReqBody Person :> Post Integer
+       ReqBody Person :> Post '[JSON] Integer
+  :<|> "bla" :> ReqBody Person :> Post '[JSON] Integer
 postApi :: Proxy PostApi
 postApi = Proxy
 
@@ -344,8 +345,8 @@ rawSpec = do
 
 
 type AlternativeApi =
-       "foo" :> Get Person
-  :<|> "bar" :> Get Animal
+       "foo" :> Get '[JSON] Person
+  :<|> "bar" :> Get '[JSON] Animal
 unionApi :: Proxy AlternativeApi
 unionApi = Proxy
 
diff --git a/test/Servant/Utils/StaticFilesSpec.hs b/test/Servant/Utils/StaticFilesSpec.hs
index 6918448..4d4b242 100644
--- a/test/Servant/Utils/StaticFilesSpec.hs
+++ b/test/Servant/Utils/StaticFilesSpec.hs
@@ -13,6 +13,7 @@ import System.IO.Temp (withSystemTempDirectory)
 import Test.Hspec (Spec, describe, it, around_)
 import Test.Hspec.Wai (with, get, shouldRespondWith)
 
+import Servant.API (JSON)
 import Servant.API.Alternative ((:<|>)((:<|>)))
 import Servant.API.Capture (Capture)
 import Servant.API.Get (Get)
@@ -23,7 +24,7 @@ import Servant.ServerSpec (Person(Person))
 import Servant.Utils.StaticFiles (serveDirectory)
 
 type Api =
-       "dummy_api" :> Capture "person_name" String :> Get Person
+       "dummy_api" :> Capture "person_name" String :> Get '[JSON] Person
   :<|> "static" :> Raw