{-# LANGUAGE FlexibleContexts #-}
module Happstack.Server.Auth where
import Data.Foldable (foldl')
import Data.Bits (xor, (.|.))
import Data.Maybe (fromMaybe)
import Control.Monad (MonadPlus(mzero, mplus))
import Data.ByteString.Base64 as Base64
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as B
import qualified Data.Map as M
import Happstack.Server.Monads (Happstack, escape, getHeaderM, setHeaderM)
import Happstack.Server.Response (unauthorized, toResponse)
-- | Guard a handler with HTTP basic authentication, checking credentials
-- against a map of user names to plaintext passwords.
--
-- This is 'basicAuthBy' specialised to 'validLoginPlaintext': the map is
-- turned into a validation predicate and everything else (header parsing,
-- the rejection response) is delegated to 'basicAuthBy'.
basicAuth :: (Happstack m) =>
   String -- ^ the realm name
   -> M.Map String String -- ^ the username password map
   -> m a -- ^ the part to guard
   -> m a
basicAuth realmName authMap =
    basicAuthBy (validLoginPlaintext authMap) realmName
-- | Guard a handler with HTTP basic authentication, using a caller-supplied
-- predicate to decide whether a name/password pair is acceptable.
--
-- The @authorization@ request header is read, its payload base64-decoded and
-- split at the first @\':\'@ into a name and a password. When the predicate
-- accepts the pair, the authentication step yields 'mzero', so the 'mplus'
-- falls through to the guarded part. In every other case (no header,
-- undecodable payload, no @\':\'@ separator, predicate returns 'False') the
-- request is answered through 'escape' with an 'unauthorized' response that
-- carries a @WWW-Authenticate@ header naming the realm.
basicAuthBy :: (Happstack m) =>
   (B.ByteString -> B.ByteString -> Bool) -- ^ function that returns true if the name/password is valid
   -> String -- ^ the realm name
   -> m a -- ^ the part to guard
   -> m a
basicAuthBy validLogin realmName xs = basicAuthImpl `mplus` xs
  where
    -- Deliberately left without a local type signature: it must share the
    -- outer @m@, which cannot be named here without ScopedTypeVariables.
    basicAuthImpl = do
        aHeader <- getHeaderM "authorization"
        case aHeader of
            Nothing -> err
            Just x -> do
                (name, rest) <- parseHeader x
                -- 'B.break' leaves the separator at the front of the second
                -- component; a total 'B.uncons' match both checks for it and
                -- strips it, so no partial head/tail is needed.
                case B.uncons rest of
                    Just (':', password)
                        | validLogin name password -> mzero
                    _ -> err

    -- Decode the header payload and split it at the first ':'.
    -- NOTE(review): the first 6 bytes are dropped unconditionally, which
    -- presumably skips a "Basic " scheme prefix; the scheme itself is never
    -- inspected, so any 6-byte prefix is tolerated. Confirm whether that is
    -- intended before tightening it.
    parseHeader h =
        case Base64.decode (B.drop 6 h) of
            Left _   -> err
            Right bs -> return (B.break (== ':') bs)

    headerName  = "WWW-Authenticate"
    headerValue = "Basic realm=\"" ++ realmName ++ "\""

    -- Reject the request: set the challenge header and escape with an
    -- 'unauthorized' response instead of running the guarded part.
    err :: (Happstack m) => m a
    err = escape $ do
        setHeaderM headerName headerValue
        unauthorized $ toResponse "Not authorized"
-- | Check a name/password pair against a map of user names to plaintext
-- passwords.
--
-- Returns 'True' only when the name is a key of the map and the stored
-- password equals the supplied one. The name and the stored password are
-- converted with the @Char8@ 'B.unpack' / 'B.pack', i.e. byte-per-'Char'.
--
-- The password comparison does not stop at the first differing byte, but an
-- unknown user name or a length mismatch still returns 'False' immediately.
validLoginPlaintext ::
  M.Map String String -- ^ the username password map
  -> B.ByteString -- ^ the username
  -> B.ByteString -- ^ the password
  -> Bool
validLoginPlaintext authMap name password = fromMaybe False $ do
    r <- M.lookup (B.unpack name) authMap
    pure (constTimeEq (B.pack r) password)
  where
    -- Equality that inspects every byte pair when the lengths agree: the
    -- XOR of each pair is OR-ed into an accumulator, which is zero exactly
    -- when all pairs were equal. Kept NOINLINE as in the original.
    {-# NOINLINE constTimeEq #-}
    constTimeEq :: BS.ByteString -> BS.ByteString -> Bool
    constTimeEq x y
      | BS.length x /= BS.length y = False
      | otherwise = foldl' (.|.) 0 (BS.zipWith xor x y) == 0