module Servant.Server.Experimental.Auth where
import Control.Monad.Trans (liftIO)
import Data.Proxy (Proxy (Proxy))
import Data.Typeable (Typeable)
import GHC.Generics (Generic)
import Network.Wai (Request)
import Servant ((:>))
import Servant.API.Experimental.Auth
import Servant.Server.Internal (HasContextEntry,
HasServer (..),
getContextEntry)
import Servant.Server.Internal.RoutingApplication (addAuthCheck,
delayedFailFatal,
DelayedIO,
withRequest)
import Servant.Server.Internal.Handler (Handler, runHandler)
-- | Open type family selecting the data type associated with an
-- authentication combinator.
--
-- In this module it is applied to @AuthProtect tag@: the resulting type is
-- both the value produced by the 'AuthHandler' looked up in the context and
-- the extra argument passed to the protected server. No instances are
-- declared here; they are expected to be supplied elsewhere.
type family AuthServerData a :: *
-- | Wrapper around a function that turns a value of type @r@ into a
-- 'Handler' action yielding a @usr@.
--
-- The 'HasServer' instance for @AuthProtect tag :> api@ fetches a handler of
-- this type (with @r@ fixed to 'Request') from the server context.
newtype AuthHandler r usr
  = AuthHandler
      { -- | Unwrap the underlying function.
        unAuthHandler :: r -> Handler usr
      }
  deriving (Generic, Typeable)
-- | Build an 'AuthHandler' from a plain function.
--
-- This is the 'AuthHandler' constructor under a function-style name.
mkAuthHandler :: (r -> Handler usr) -> AuthHandler r usr
mkAuthHandler handler = AuthHandler { unAuthHandler = handler }
-- | Server instance for an API guarded by @AuthProtect tag@.
--
-- Requires that the context holds an 'AuthHandler' taking a 'Request' and
-- producing @AuthServerData (AuthProtect tag)@. That handler is run as an
-- auth check, and its result is handed to the wrapped server as an
-- additional leading argument.
instance
  ( HasServer api context
  , HasContextEntry context (AuthHandler Request (AuthServerData (AuthProtect tag)))
  )
  => HasServer (AuthProtect tag :> api) context
  where

  -- A protected server is the inner server preceded by one extra argument:
  -- the data produced by the auth handler.
  type ServerT (AuthProtect tag :> api) m =
    AuthServerData (AuthProtect tag) -> ServerT api m

  -- Hoisting leaves the auth argument alone and hoists the server obtained
  -- once that argument has been applied.
  hoistServerWithContext _ ctxProxy natTrans server = \authData ->
    hoistServerWithContext (Proxy :: Proxy api) ctxProxy natTrans (server authData)

  -- Attach the request-based auth check to the delayed server, then route
  -- the remainder of the API as usual.
  route Proxy context subserver =
    route (Proxy :: Proxy api) context (addAuthCheck subserver (withRequest authCheck))
    where
      -- The user-supplied handler, pulled out of the server context.
      authHandler :: Request -> Handler (AuthServerData (AuthProtect tag))
      authHandler = unAuthHandler (getContextEntry context)

      -- Run the handler in IO; a 'Left' result is passed to
      -- 'delayedFailFatal', a 'Right' result becomes the check's value.
      authCheck :: Request -> DelayedIO (AuthServerData (AuthProtect tag))
      authCheck request = do
        outcome <- liftIO (runHandler (authHandler request))
        case outcome of
          Left failure -> delayedFailFatal failure
          Right authData -> return authData