{-# LANGUAGE FlexibleContexts #-}
-- | Support for basic access authentication <http://en.wikipedia.org/wiki/Basic_access_authentication>
module Happstack.Server.Auth where

import Data.Bits (xor, (.|.))
import Data.Char (toLower)
import Data.Foldable (foldl')
import Data.Maybe (fromMaybe)
import Control.Monad                             (MonadPlus(mzero, mplus))
import Data.ByteString.Base64                    as Base64
import qualified Data.ByteString                 as BS
import qualified Data.ByteString.Char8           as B
import qualified Data.Map                        as M
import Happstack.Server.Monads                   (Happstack, escape, getHeaderM, setHeaderM)
import Happstack.Server.Response                 (unauthorized, toResponse)

-- | A simple HTTP basic authentication guard.
--
-- Credentials from the @Authorization@ header are checked against the
-- username/password map with 'validLoginPlaintext'. If they are valid, the
-- guarded part is run. Otherwise the request is answered immediately (via
-- 'escape') with a 401 \"Not authorized\" response carrying a
-- @WWW-Authenticate@ challenge for the realm. Because of the 'escape',
-- alternatives placed after this part in an enclosing 'msum' are not tried
-- when authentication fails.
--
-- example:
--
-- > main = simpleHTTP nullConf $
-- >   basicAuth "127.0.0.1" (fromList [("happstack","rocks")]) $ ok "You are in the secret club"
--
basicAuth :: (Happstack m) =>
   String -- ^ the realm name
   -> M.Map String String -- ^ the username password map
   -> m a -- ^ the part to guard
   -> m a
basicAuth realmName authMap = basicAuthBy (validLoginPlaintext authMap) realmName


-- | Generalized version of 'basicAuth'.
--
-- The function that checks the username password combination must be
-- supplied as first argument. It receives the raw name bytes and the raw
-- password bytes decoded from the @Authorization: Basic ...@ header.
--
-- If the credentials are valid, the guarded part is run. In every other case
-- the request is answered immediately (via 'escape') with a 401
-- \"Not authorized\" response and a @WWW-Authenticate@ challenge for the realm.
-- That case covers a missing header, a scheme other than @Basic@, undecodable
-- base64, no @:@ separator, or rejected credentials.
--
-- example:
--
-- > main = simpleHTTP nullConf $
-- >   basicAuthBy (validLoginPlaintext (fromList [("happstack","rocks")])) "127.0.0.1" $
-- >     ok "You are in the secret club"
--
basicAuthBy :: (Happstack m) =>
   (B.ByteString -> B.ByteString -> Bool) -- ^ function that returns true if the name password combination is valid
   -> String -- ^ the realm name
   -> m a -- ^ the part to guard
   -> m a
basicAuthBy validLogin realmName xs = basicAuthImpl `mplus` xs
  where
    -- Succeeds with 'mzero' only for valid credentials, so 'mplus' falls
    -- through to the guarded part; every failure escapes with a 401.
    basicAuthImpl = do
        aHeader <- getHeaderM "authorization"
        case aHeader of
            Nothing -> err
            Just x  -> do
                (name, afterName) <- parseHeader x
                -- afterName starts at the ':' separator (if there is one).
                case B.uncons afterName of
                    Just (':', password) | validLogin name password -> mzero
                    _                                               -> err

    -- Splits "Basic <base64(name:password)>" into the name and the remainder
    -- beginning at the first ':'. The scheme token is matched
    -- case-insensitively and the run of spaces before the credentials is
    -- skipped (RFC 7617: auth-scheme 1*SP token68).
    parseHeader h
      | B.map toLower scheme /= B.pack "basic" = err
      | otherwise =
          case Base64.decode (B.dropWhile (== ' ') credentials) of
            Left _   -> err
            Right bs -> return (B.break (== ':') bs)
      where
        (scheme, credentials) = B.break (== ' ') h

    headerName  = "WWW-Authenticate"
    headerValue = "Basic realm=\"" ++ concatMap escapeQuoted realmName ++ "\""

    -- Backslash-escape characters that would otherwise terminate or corrupt
    -- the quoted-string realm value.
    escapeQuoted c
      | c == '"' || c == '\\' = ['\\', c]
      | otherwise             = [c]

    err :: (Happstack m) => m a
    err = escape $ do
            setHeaderM headerName headerValue
            unauthorized $ toResponse "Not authorized"


-- | Function that looks up the plain text password for username in a
-- Map and returns True if it matches with the given password.
--
-- The name bytes are turned into a 'String' one byte per 'Char' for the
-- lookup, and the stored password is turned into bytes one 'Char' per byte
-- for the comparison. In effect both sides are compared as Latin-1.
--
-- A stored password containing characters above code point 255 never
-- matches. 'B.pack' would silently keep only the low byte of such
-- characters, which lets distinct passwords collide. For example, a stored
-- U+0101 would otherwise accept the single byte 0x01.
--
-- Note: The implementation is hardened against timing attacks but not
-- completely safe. Ideally you should build your own predicate, using
-- a robust constant-time equality comparison from a cryptographic
-- library like sodium.
validLoginPlaintext ::
  M.Map String String -- ^ the username password map
  -> B.ByteString -- ^ the username
  -> B.ByteString -- ^ the password
  -> Bool
validLoginPlaintext authMap name password = fromMaybe False $ do
    r <- M.lookup (B.unpack name) authMap
    -- Reject passwords that 'B.pack' cannot represent faithfully instead of
    -- comparing a truncated copy.
    pure (all (<= '\255') r && constTimeEq (B.pack r) password)
  where
    -- (Mostly) constant time equality of bytestrings to prevent timing attacks by testing out passwords. This still
    -- allows to extract the length of the configured password via timing attacks. This implementation is still brittle
    -- in the sense that it relies on GHC not unrolling or vectorizing the loop.
    {-# NOINLINE constTimeEq #-}
    constTimeEq :: BS.ByteString -> BS.ByteString -> Bool
    constTimeEq x y
      | BS.length x /= BS.length y
      = False

      | otherwise
      = foldl' (.|.) 0 (BS.zipWith xor x y) == 0