{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE OverloadedStrings #-}
#if __GLASGOW_HASKELL__ < 710
{-# LANGUAGE OverlappingInstances #-}
#endif
{-# LANGUAGE ScopedTypeVariables #-}
module Servant.Server.Internal.BasicAuth where
#if __GLASGOW_HASKELL__ < 710
import Data.Functor ((<$>))
#endif
import Control.Monad (guard)
import qualified Data.ByteString as BS
import Data.ByteString.Base64 (decodeLenient)
import Data.CaseInsensitive (CI(..))
import Data.Monoid ((<>))
import Data.Typeable (Typeable)
import Data.Word8 (isSpace, toLower, _colon)
import GHC.Generics
import Snap.Core
import Servant.API.BasicAuth (BasicAuthData(BasicAuthData))
import Servant.Server.Internal.RoutingApplication
import Servant.Server.Internal.ServantErr
-- | Outcome of checking a set of Basic-auth credentials.
--
-- 'runBasicAuth' turns 'BadPassword' and 'NoSuchUser' into a 401 response
-- carrying a @WWW-Authenticate@ challenge, 'Unauthorized' into a 403
-- response, and 'Authorized' into success with the contained user value.
data BasicAuthResult usr
  = Unauthorized     -- ^ Credentials were understood but access is refused (403).
  | BadPassword      -- ^ The user exists but the password did not match (401).
  | NoSuchUser       -- ^ No user with the supplied user-id exists (401).
  | Authorized usr   -- ^ Authentication succeeded; carries the user value.
  deriving (Eq, Show, Read, Generic, Typeable, Functor)
-- | A credential check for Basic auth: given the credentials decoded from
-- the request, run an action in @m@ that decides a 'BasicAuthResult'.
newtype BasicAuthCheck m usr = BasicAuthCheck
  { unBasicAuthCheck :: BasicAuthData
                     -> m (BasicAuthResult usr)
    -- ^ The check itself; its result drives the response chosen by
    -- 'runBasicAuth'.
  }
  deriving (Generic, Typeable, Functor)
-- | Build the @WWW-Authenticate@ challenge header for HTTP Basic auth,
-- e.g. @Basic realm=\"example\"@.
--
-- The realm is emitted as a quoted-string (RFC 7235 / RFC 7230 §3.2.6).
-- Any double-quote or backslash byte in it is therefore backslash-escaped,
-- so the resulting header stays well-formed. Realms without those bytes
-- are emitted unchanged.
mkBAChallengerHdr :: BS.ByteString -> (CI BS.ByteString, BS.ByteString)
mkBAChallengerHdr realm =
  ("WWW-Authenticate", "Basic realm=\"" <> quoteRealm realm <> "\"")
  where
    -- Only DQUOTE (0x22) and backslash (0x5C) may not appear bare inside
    -- a quoted-string; each is emitted as a quoted-pair (backslash + byte).
    quoteRealm = BS.concatMap escapeByte
    escapeByte w
      | w == 0x22 || w == 0x5C = BS.pack [0x5C, w]
      | otherwise              = BS.singleton w
-- | Extract Basic-auth credentials from a request's @Authorization@ header.
--
-- The scheme token (everything before the first whitespace byte) is
-- lower-cased byte-wise and must equal @basic@. The rest, minus leading
-- whitespace, is base64-decoded with 'decodeLenient' and split at the
-- first colon into user-id and password.
--
-- Yields 'Nothing' when the header is absent, the scheme is not Basic, or
-- the decoded payload contains no colon.
decodeBAHdr :: Request -> Maybe BasicAuthData
decodeBAHdr req =
  getHeader "Authorization" req >>= \hdr ->
    let (scheme, afterScheme) = BS.break isSpace hdr
        payload               = decodeLenient (BS.dropWhile isSpace afterScheme)
        (user, colonAndPass)  = BS.break (== _colon) payload
    in if BS.map toLower scheme /= "basic"
         then Nothing
         -- uncons drops the separating colon; no colon means no credentials
         else fmap (BasicAuthData user . snd) (BS.uncons colonAndPass)
-- | Authenticate a request with HTTP Basic auth inside 'DelayedM'.
--
-- * Missing or unparseable Basic credentials, 'BadPassword', or
--   'NoSuchUser': fail via 'delayedFailFatal' with 401 and a
--   @WWW-Authenticate@ challenge for @realm@.
-- * 'Unauthorized': fail via 'delayedFailFatal' with 403 and no challenge.
-- * 'Authorized' @usr@: succeed with @usr@.
runBasicAuth :: MonadSnap m => Request -> BS.ByteString -> BasicAuthCheck m usr -> DelayedM m usr
runBasicAuth req realm (BasicAuthCheck ba) =
  maybe challenge checkCredentials (decodeBAHdr req)
  where
    -- Lift the user-supplied check into DelayedM, then act on its verdict.
    checkCredentials creds = do
      verdict <- DelayedM (const $ Route <$> ba creds)
      interpret verdict

    interpret (Authorized usr) = return usr
    interpret Unauthorized     = delayedFailFatal err403
    interpret BadPassword      = challenge
    interpret NoSuchUser       = challenge

    -- 401 with a challenge so clients know to (re)send credentials.
    challenge = delayedFailFatal err401 { errHeaders = [mkBAChallengerHdr realm] }