{-# LANGUAGE DeriveFunctor #-}

module Web.Minion.Auth where

import Data.Kind (Type)
import Data.Void (Void, absurd)
import Network.Wai qualified as Wai
import Web.Minion.Args (GetByType (getByType), HList, WithReq)
import Web.Minion.Error
import Web.Minion.Introspect qualified as I
import Web.Minion.Request
import Web.Minion.Router

-- | Request value produced by the 'auth' combinator: the payload @a@ returned
-- by a successful authenticator, tagged at the type level with the list of
-- authenticators (@auths@) that were tried.
newtype Auth (auths :: [Type]) a
  = Auth a

-- | Handlers receive the authenticated payload itself; the type-level list
-- of authenticators is dropped when the value is extracted.
instance IsRequest (Auth auths a) where
  type RequestValue (Auth auths a) = a
  getRequestValue (Auth a) = a

-- | Outcome of running a single authenticator against a request.
data AuthResult a
  = -- | The authenticator could not decide; 'auth' moves on to the next
    -- authenticator in the list.
    Indefinite
  | -- | The authenticator rejected the request; 'auth' stops trying further
    -- authenticators.
    BadAuth
  | -- | Authentication succeeded, yielding the given value.
    Authenticated a
  deriving (Functor)

-- | Converts the type-level list of authenticators @auths@ into a runtime
-- list of authentication actions. Each action receives the context 'HList'
-- (from which it may read its settings), the error builder and the incoming
-- request.
class UnwindAuth (ctx :: [Type]) (auths :: [Type]) m a where
  -- | One action per authenticator, in the order they appear in @auths@.
  unwindAuth ::
    [ HList ctx ->
      ErrorBuilder ->
      Wai.Request ->
      m (AuthResult a)
    ]

-- | An authentication scheme, identified by the type @auth@, that attempts to
-- produce a value of type @a@ in monad @m@.
class IsAuth (auth :: Type) m a where
  -- | Configuration the scheme needs. 'UnwindAuth' looks it up by type in the
  -- context 'HList'.
  type Settings auth m a :: Type

  -- | Run the scheme against a request using its settings.
  toAuth ::
    Settings auth m a ->
    ErrorBuilder ->
    Wai.Request ->
    m (AuthResult a)

-- | Inductive case. The head authenticator contributes its 'toAuth', fed
-- with the 'Settings' value found by type in the context 'HList'. The actions
-- for the remaining authenticators follow it in order.
instance
  ( IsAuth auth m a
  , UnwindAuth ctx auths m a
  , GetByType (Settings auth m a) ctx
  ) =>
  UnwindAuth ctx (auth ': auths) m a
  where
  {-# INLINE unwindAuth #-}
  unwindAuth = (toAuth @auth . getByType) : unwindAuth @ctx @auths

-- | Base case: an empty list of authenticators yields no actions.
instance UnwindAuth ctx '[] m a where
  {-# INLINE unwindAuth #-}
  unwindAuth = []

{-# INLINE auth #-}
-- | Request combinator that authenticates the request with the authenticators
-- listed in @auths@ and passes the result on wrapped in 'Auth'.
--
-- The authenticators are tried left to right. An 'Indefinite' result moves on
-- to the next one, and the first other result ends the search. If the search
-- does not end in 'Authenticated', the continuation is called with the
-- 'MakeError' for this request and the failure that ended it:
--
-- * 'BadAuth' if some authenticator rejected the request;
-- * 'Indefinite' if every authenticator was undecided, including the case
--   where @auths@ is empty.
auth ::
  forall auths a m ctx ts i.
  (I.Introspection i I.Request (Auth auths a)) =>
  (UnwindAuth ctx auths m a) =>
  (MonadThrow m) =>
  -- | Context holding each authenticator's settings; run once per request.
  m (HList ctx) ->
  -- | Handle non-Authenticated results.
  (MakeError -> AuthResult Void -> m Void) ->
  ValueCombinator i (WithReq m (Auth auths a)) ts m
auth ctxm cont = Request \errorBuilder req -> do
  -- The context is fetched once and shared by every authenticator.
  ctx <- ctxm
  let authenticators ::
        [HList ctx -> ErrorBuilder -> Wai.Request -> m (AuthResult a)]
      authenticators = unwindAuth @ctx @auths @m @a

      -- Run the authenticators in order until one gives a decisive answer.
      firstDecisive [] = pure Indefinite
      firstDecisive (authenticate : rest) =
        authenticate ctx errorBuilder req >>= \case
          Indefinite -> firstDecisive rest
          decisive -> pure decisive

      -- Forward the actual failure so the handler can tell a rejection
      -- ('BadAuth') apart from "nobody could decide" ('Indefinite').
      reject failure = absurd <$> cont (errorBuilder req) failure

  firstDecisive authenticators >>= \case
    Authenticated value -> pure (Auth value)
    BadAuth -> reject BadAuth
    Indefinite -> reject Indefinite