{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
module Control.SafeAccess
( Capability(..)
, Capabilities
, AccessDecision(..)
, SafeAccessT(..)
, MonadSafeAccess(..)
, ensureAccess
, unsecureAllow
, singleCapability
, someCapabilities
, passthroughCapability
, liftExceptT
, liftCapability
) where
import Control.Monad
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Trans.Cont (ContT(..))
import Control.Monad.Trans.Identity
import Control.Monad.Trans.List
import Control.Monad.Trans.Maybe
import Control.Monad.Trans.RWS (RWST(..))
import Control.Monad.Writer
-- | A capability: a function from an access descriptor of type @d@ to an
-- 'AccessDecision' computed in the effect @m@.
newtype Capability m d = MkCapability
  { runCapability :: d -> m AccessDecision
    -- ^ Evaluate the capability against a descriptor.
  }
-- | Combining two capabilities yields a capability that evaluates both
-- against the same descriptor (left operand first) and merges the two
-- decisions with the 'AccessDecision' semigroup.
instance Applicative m => Semigroup (Capability m d) where
  lhs <> rhs = MkCapability decide
    where
      decide descr =
        fmap (<>) (runCapability lhs descr) <*> runCapability rhs descr
-- | The neutral capability ignores the descriptor and answers with the
-- neutral 'AccessDecision' ('mempty').
instance Applicative m => Monoid (Capability m d) where
  mempty = MkCapability (const (pure mempty))
  mappend = (<>)
-- | A list of capabilities sharing the same effect @m@ and descriptor
-- type @d@.
type Capabilities m d = [Capability m d]
-- | Capability built from exactly one descriptor; it is 'someCapabilities'
-- applied to the singleton list containing that descriptor.
singleCapability :: (Applicative f, Eq d) => d -> Capability f d
singleCapability descr = someCapabilities [descr]
-- | Capability that answers 'AccessGranted' for every descriptor found in
-- the given list and 'AccessDeniedSoft' for any other descriptor.
someCapabilities :: (Applicative f, Eq d) => [d] -> Capability f d
someCapabilities allowed = MkCapability (pure . decide)
  where
    decide descr
      | descr `elem` allowed = AccessGranted
      | otherwise            = AccessDeniedSoft
-- | Capability that answers 'AccessGranted' for every descriptor, without
-- looking at it.
passthroughCapability :: Applicative f => Capability f d
passthroughCapability = MkCapability (const (pure AccessGranted))
-- | Result of evaluating a capability against a descriptor.
data AccessDecision
  = AccessDeniedSoft
    -- ^ Neutral denial: the identity of '<>', so it gives way to whatever
    -- decision it is combined with.
  | AccessGranted
    -- ^ Access is allowed, unless combined with 'AccessDenied'.
  | AccessDenied
    -- ^ Hard denial: the combined result is 'AccessDenied' even when the
    -- other decision is 'AccessGranted'.
  deriving (Show, Eq)

-- | Decisions combine as follows: 'AccessDeniedSoft' yields to the other
-- operand, 'AccessDenied' wins over 'AccessGranted', and two grants stay a
-- grant.
instance Semigroup AccessDecision where
  AccessDeniedSoft <> other            = other
  other            <> AccessDeniedSoft = other
  AccessDenied     <> _                = AccessDenied
  _                <> AccessDenied     = AccessDenied
  AccessGranted    <> AccessGranted    = AccessGranted

-- | The neutral element is 'AccessDeniedSoft'.
instance Monoid AccessDecision where
  mempty  = AccessDeniedSoft
  mappend = (<>)
-- | A computation over @m@ that is given a list of capabilities and ends
-- either with a descriptor of type @d@ (@Left@) or with a result of type
-- @a@ (@Right@).
newtype SafeAccessT d m a = SafeAccessT
  { runSafeAccessT :: Capabilities m d -> m (Either d a)
    -- ^ Run the computation with the given capabilities.
  }
-- | Maps over the @Right@ result; a @Left@ descriptor is passed through
-- untouched.
instance Functor f => Functor (SafeAccessT d f) where
  fmap g (SafeAccessT run) = SafeAccessT (fmap (fmap g) . run)
-- | Applicative composition that stops at the first @Left@.
--
-- The function computation is run first. If it ends in @Left@, the argument
-- computation is not run at all, so none of its effects in the underlying
-- monad take place. This keeps '<*>' (and everything defined through it,
-- such as '*>' and 'traverse') in agreement with '>>=', which short-circuits
-- in the same way.
--
-- Deciding whether to run the second computation from the outcome of the
-- first needs 'Monad' on the underlying type, hence the @Monad f@ context.
instance Monad f => Applicative (SafeAccessT d f) where
  pure = SafeAccessT . const . pure . Right
  safef <*> safea = SafeAccessT $ \caps -> do
    ef <- runSafeAccessT safef caps
    case ef of
      -- Propagate the descriptor without touching the argument computation.
      Left d  -> return (Left d)
      Right f -> fmap (fmap f) (runSafeAccessT safea caps)
-- | Sequencing passes the same capabilities to both steps. When the first
-- step ends in @Left@, the continuation is not run and the descriptor is
-- returned as is.
instance Monad m => Monad (SafeAccessT d m) where
  return x = SafeAccessT $ \_ -> return (Right x)
  SafeAccessT run >>= k = SafeAccessT $ \caps ->
    run caps >>= either (return . Left) (\x -> runSafeAccessT (k x) caps)
-- | A lifted action ignores the capabilities and wraps its result in
-- @Right@.
instance MonadTrans (SafeAccessT d) where
  lift action = SafeAccessT $ \_ -> liftM Right action
-- | An 'IO' action is lifted into the underlying monad first; the
-- capabilities are ignored and the result is wrapped in @Right@.
instance MonadIO m => MonadIO (SafeAccessT d m) where
  liftIO io = SafeAccessT $ \_ -> liftM Right (liftIO io)
-- | Check a descriptor against the current capabilities.
--
-- Every capability is evaluated (in the sub-monad, via 'liftSub') and the
-- decisions are merged with 'mconcat'. Only a merged 'AccessGranted' lets the
-- computation continue; any other outcome calls 'denyAccess' with the
-- descriptor.
ensureAccess :: MonadSafeAccess d m s => d -> m ()
ensureAccess descr = do
  caps <- getCapabilities
  decisions <- liftSub (mapM (`runCapability` descr) caps)
  verdict (mconcat decisions)
  where
    verdict AccessGranted = return ()
    verdict _             = denyAccess descr
-- | Run an action with the listed descriptors unconditionally allowed.
--
-- The action is run with a single replacement capability. For a descriptor
-- in the list it answers 'AccessGranted' without consulting the surrounding
-- capabilities at all; for any other descriptor it evaluates the surrounding
-- capabilities and merges their decisions with 'mconcat'.
unsecureAllow :: (Monad m, Eq d) => [d] -> SafeAccessT d m a
              -> SafeAccessT d m a
unsecureAllow descrs action = SafeAccessT $ \caps ->
  runSafeAccessT action [override caps]
  where
    -- Capability that short-cuts the listed descriptors and otherwise
    -- defers to the capabilities that were in effect outside.
    override outer = MkCapability $ \descr ->
      if descr `elem` descrs
        then pure AccessGranted
        else mconcat <$> traverse (`runCapability` descr) outer
-- | Return, as the result, the capability list the computation is being
-- run with.
getCapabilities' :: Monad m => SafeAccessT d m (Capabilities m d)
getCapabilities' = SafeAccessT (\caps -> return (Right caps))
-- | End the computation with the given descriptor as a @Left@, regardless
-- of the capabilities.
denyAccess' :: Monad m => d -> SafeAccessT d m ()
denyAccess' descr = SafeAccessT $ \_ -> return (Left descr)
-- | Run an action and, if it ends with a @Left@ descriptor, run the handler
-- on that descriptor with the same capabilities. A @Right@ result is
-- returned unchanged.
catchAccessError' :: Monad m => SafeAccessT d m a -> (d -> SafeAccessT d m a)
                  -> SafeAccessT d m a
catchAccessError' action handler = SafeAccessT $ \caps ->
  runSafeAccessT action caps >>= either (recover caps) (return . Right)
  where
    recover caps descr = runSafeAccessT (handler descr) caps
-- | Monads @m@ that carry capabilities over descriptors of type @d@, where
-- the capabilities themselves run in a sub-monad @s@. Both @s@ and @d@ are
-- determined by @m@.
class (Monad m, Monad s) => MonadSafeAccess d m s | m -> s, m -> d where
  -- | Retrieve the capabilities, expressed in the sub-monad @s@.
  getCapabilities :: m (Capabilities s d)

  -- | Lift an action of the sub-monad @s@ into @m@.
  liftSub :: s a -> m a

  -- | Signal that access to the given descriptor is denied.
  denyAccess :: d -> m ()

  -- | Run an action, passing the descriptor to the handler if the action
  -- ended with an access error.
  catchAccessError :: m a -> (d -> m a) -> m a
-- | Base instance: the sub-monad of @SafeAccessT d m@ is @m@ itself, and
-- every method is the corresponding primed helper (or plain 'lift').
instance Monad m => MonadSafeAccess d (SafeAccessT d m) m where
  getCapabilities                 = getCapabilities'
  liftSub sub                     = lift sub
  denyAccess descr                = denyAccess' descr
  catchAccessError action handler = catchAccessError' action handler
-- | Turn an 'ExceptT' computation into a 'SafeAccessT' one: the capabilities
-- are ignored and the @Either@ produced by 'runExceptT' is used as is.
liftExceptT :: ExceptT d m a -> SafeAccessT d m a
liftExceptT action = SafeAccessT $ \_ -> runExceptT action
-- | Errors of the underlying monad are thrown and caught there; the handler
-- is run with the same capabilities as the protected action.
instance MonadError e m => MonadError e (SafeAccessT d m) where
  throwError err = lift (throwError err)
  catchError action handler = SafeAccessT $ \caps ->
    catchError (runSafeAccessT action caps)
               (flip runSafeAccessT caps . handler)
-- | The reader environment is that of the underlying monad; 'local' is
-- applied to the underlying computation obtained for the given capabilities.
instance MonadReader r m => MonadReader r (SafeAccessT d m) where
  ask = lift ask
  local f action = SafeAccessT $ \caps ->
    local f (runSafeAccessT action caps)
-- | State operations are delegated to the underlying monad.
instance MonadState s m => MonadState s (SafeAccessT d m) where
  state f = lift (state f)
-- | Writer operations are delegated to the underlying monad, with the
-- @Either@ layer threaded through 'listen' and 'pass'.
instance MonadWriter w m => MonadWriter w (SafeAccessT d m) where
  writer = lift . writer

  -- The output collected by the underlying 'listen' is paired with the
  -- result only when the computation ended in @Right@.
  listen action = SafeAccessT $ \caps ->
    attach `liftM` listen (runSafeAccessT action caps)
    where
      attach (result, w) = fmap (\x -> (x, w)) result

  -- A @Left@ outcome has no output-transforming function to apply, so 'id'
  -- is handed to the underlying 'pass' in that case.
  pass action = SafeAccessT $ \caps ->
    pass (split `liftM` runSafeAccessT action caps)
    where
      split (Left d)       = (Left d, id)
      split (Right (x, f)) = (Right x, f)
-- | Lift a capability through a monad transformer by applying 'lift' to the
-- decision it computes.
liftCapability :: (Monad m, MonadTrans t)
               => Capability m d -> Capability (t m) d
liftCapability (MkCapability check) = MkCapability (lift . check)
-- | 'ContT' instance: everything is lifted from the inner monad.
--
-- 'catchAccessError' runs both the action and the handler with the same
-- continuation, and the inner catch wraps the action together with that
-- continuation.
instance MonadSafeAccess d m s => MonadSafeAccess d (ContT e m) s where
  getCapabilities  = lift getCapabilities
  liftSub sub      = lift (liftSub sub)
  denyAccess descr = lift (denyAccess descr)
  catchAccessError action handler = ContT $ \k ->
    catchAccessError (runContT action k) ((`runContT` k) . handler)
-- | 'ExceptT' instance: everything is lifted from the inner monad;
-- 'catchAccessError' unwraps the action and the handler with 'runExceptT',
-- catches in the inner monad and re-wraps the result.
instance MonadSafeAccess d m s => MonadSafeAccess d (ExceptT e m) s where
  getCapabilities  = lift getCapabilities
  liftSub sub      = lift (liftSub sub)
  denyAccess descr = lift (denyAccess descr)
  catchAccessError action handler =
    ExceptT (catchAccessError (runExceptT action)
                              (\descr -> runExceptT (handler descr)))
-- | 'StateT' instance: everything is lifted from the inner monad.
--
-- In 'catchAccessError' the action and the handler are both started from
-- the same incoming state.
instance MonadSafeAccess d m s => MonadSafeAccess d (StateT s' m) s where
  getCapabilities  = lift getCapabilities
  liftSub sub      = lift (liftSub sub)
  denyAccess descr = lift (denyAccess descr)
  catchAccessError action handler = StateT $ \st ->
    catchAccessError (runStateT action st) ((`runStateT` st) . handler)
-- | 'IdentityT' instance: everything is lifted from the inner monad;
-- 'catchAccessError' unwraps with 'runIdentityT', catches in the inner
-- monad and re-wraps the result.
instance MonadSafeAccess d m s => MonadSafeAccess d (IdentityT m) s where
  getCapabilities  = lift getCapabilities
  liftSub sub      = lift (liftSub sub)
  denyAccess descr = lift (denyAccess descr)
  catchAccessError action handler =
    IdentityT (catchAccessError (runIdentityT action)
                                (\descr -> runIdentityT (handler descr)))
-- | 'ListT' instance: everything is lifted from the inner monad;
-- 'catchAccessError' unwraps with 'runListT', catches in the inner monad
-- and re-wraps the result.
instance MonadSafeAccess d m s => MonadSafeAccess d (ListT m) s where
  getCapabilities  = lift getCapabilities
  liftSub sub      = lift (liftSub sub)
  denyAccess descr = lift (denyAccess descr)
  catchAccessError action handler =
    ListT (catchAccessError (runListT action)
                            (\descr -> runListT (handler descr)))
-- | 'MaybeT' instance: everything is lifted from the inner monad;
-- 'catchAccessError' unwraps with 'runMaybeT', catches in the inner monad
-- and re-wraps the result.
instance MonadSafeAccess d m s => MonadSafeAccess d (MaybeT m) s where
  getCapabilities  = lift getCapabilities
  liftSub sub      = lift (liftSub sub)
  denyAccess descr = lift (denyAccess descr)
  catchAccessError action handler =
    MaybeT (catchAccessError (runMaybeT action)
                             (\descr -> runMaybeT (handler descr)))
-- | 'ReaderT' instance: everything is lifted from the inner monad;
-- in 'catchAccessError' the action and the handler are run with the same
-- environment.
instance MonadSafeAccess d m s => MonadSafeAccess d (ReaderT r m) s where
  getCapabilities  = lift getCapabilities
  liftSub sub      = lift (liftSub sub)
  denyAccess descr = lift (denyAccess descr)
  catchAccessError action handler = ReaderT $ \env ->
    catchAccessError (runReaderT action env) ((`runReaderT` env) . handler)
-- | 'RWST' instance: everything is lifted from the inner monad;
-- in 'catchAccessError' the action and the handler are run with the same
-- environment and the same incoming state.
instance (MonadSafeAccess d m s, Monoid w)
      => MonadSafeAccess d (RWST r w st m) s where
  getCapabilities  = lift getCapabilities
  liftSub sub      = lift (liftSub sub)
  denyAccess descr = lift (denyAccess descr)
  catchAccessError action handler = RWST $ \env st ->
    let recover descr = runRWST (handler descr) env st
    in  catchAccessError (runRWST action env st) recover
-- | 'WriterT' instance: everything is lifted from the inner monad;
-- 'catchAccessError' unwraps with 'runWriterT', catches in the inner monad
-- and re-wraps the result.
instance (MonadSafeAccess d m s, Monoid w)
      => MonadSafeAccess d (WriterT w m) s where
  getCapabilities  = lift getCapabilities
  liftSub sub      = lift (liftSub sub)
  denyAccess descr = lift (denyAccess descr)
  catchAccessError action handler =
    WriterT (catchAccessError (runWriterT action)
                              (\descr -> runWriterT (handler descr)))