module Web.Minion.Auth.Jwt where

import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Maybe (MaybeT (..))
import Crypto.JOSE qualified as Jose
import Crypto.JWT (JWTError)
import Crypto.JWT qualified as Jose
import Data.Aeson (FromJSON (..))
import Data.ByteString qualified as Bytes
import Data.ByteString.Lazy qualified as Bytes.Lazy
import Data.Function ((&))
import Data.Functor ((<&>))
import Data.Time qualified as Time
import Network.HTTP.Types.Header qualified as Http
import Network.Wai qualified as Wai
import Web.Minion

-- | Configuration for JWT bearer authentication.
--
-- Every field is an action in @m@; the 'IsAuth' instance for 'Bearer' runs
-- 'getNow', 'jwk' and 'validationSettings' each time it authenticates a
-- request, so they may return fresh values per call.
data JwtAuthSettings m payload a = JwtAuthSettings
  { getNow :: m Time.UTCTime
  -- ^ Clock used as the verification instant for 'Jose.verifyJWTAt'.
  , jwk :: m Jose.JWK
  -- ^ Key the token signature is verified against.
  , validationSettings :: m Jose.JWTValidationSettings
  -- ^ Validation settings handed to 'Jose.verifyJWTAt'.
  , check ::
      MakeError ->
      Either JWTError (JwtPayload payload) ->
      m (AuthResult a)
  -- ^ Maps the verification outcome (a 'JWTError' or the decoded payload)
  -- to an 'AuthResult'. It is not consulted when no bearer credentials are
  -- present; the instance returns 'Indefinite' in that case.
  }

-- | Build 'JwtAuthSettings' that read the system clock
-- ('Time.getCurrentTime') and validate tokens with jose's
-- 'Jose.defaultJWTValidationSettings' for the supplied audience predicate.
defaultJwtAuthSettings ::
  (MonadIO m) =>
  -- | Action producing the verification key
  m Jose.JWK ->
  -- | Audience predicate
  (Jose.StringOrURI -> Bool) ->
  -- | Decides the 'AuthResult' from the verification outcome
  (MakeError -> Either JWTError (JwtPayload payload) -> m (AuthResult a)) ->
  JwtAuthSettings m payload a
defaultJwtAuthSettings loadKey audiencePredicate decide =
  JwtAuthSettings
    { check = decide
    , validationSettings = pure defaults
    , jwk = loadKey
    , getNow = liftIO Time.getCurrentTime
    }
  where
    -- The settings are pure, so they are built once and re-wrapped with 'pure'.
    defaults = Jose.defaultJWTValidationSettings audiencePredicate

-- | The decoded contents of a JWT: jose's standard claim set together with
-- an application-defined payload. Both are decoded from the same JSON value
-- (see the 'FromJSON' instance).
data JwtPayload a = JwtPayload
  { claims :: Jose.ClaimsSet
  -- ^ Standard JWT claims as modelled by 'Jose.ClaimsSet'.
  , payload :: a
  -- ^ Application-specific part of the claims.
  }

-- | Lens onto the 'claims' field. 'Jose.verifyJWTAt' requires this instance
-- for its result type; the application payload is carried through unchanged.
instance Jose.HasClaimsSet (JwtPayload a) where
  claimsSet focus (JwtPayload standard extra) =
    focus standard <&> \standard' -> JwtPayload standard' extra

-- | Parses the standard claims and the custom payload from the very same JSON
-- value, so the custom fields live next to the standard claims rather than
-- under a nested key.
instance (FromJSON a) => FromJSON (JwtPayload a) where
  parseJSON json = do
    standard <- parseJSON json
    custom <- parseJSON json
    pure (JwtPayload standard custom)

-- | Type-level tag selecting JWT bearer-token authentication whose custom
-- claims decode to @payload@. It has no values; it only picks the 'IsAuth'
-- instance.
data Bearer payload

-- | JWT bearer authentication via the @Authorization@ request header.
--
-- * If the header is missing, or its scheme is not @Bearer@, the result is
--   'Indefinite' and 'check' is not called.
-- * Otherwise the credentials are decoded as a compact JWS and verified with
--   'Jose.verifyJWTAt'. 'check' then receives either the 'JWTError' or the
--   verified 'JwtPayload'.
--
-- The scheme is matched case-insensitively and may be followed by one or
-- more spaces (RFC 7235 §2.1).
instance (MonadIO m, FromJSON payload) => IsAuth (Bearer payload) m a where
  type Settings (Bearer payload) m a = JwtAuthSettings m payload a
  toAuth JwtAuthSettings{..} buildError req = do
    jwk_ <- jwk
    now <- getNow
    settings <- validationSettings
    -- 'MaybeT' separates "no bearer credentials" (Nothing) from JOSE
    -- failures, which surface as 'Left' from 'Jose.runJOSE'.
    outcome <- Jose.runJOSE $ runMaybeT do
      authHeader <- Wai.requestHeaders req & lookup Http.hAuthorization & hoistMaybe
      compact <- hoistMaybe $ bearerToken authHeader
      jwt <- Jose.decodeCompact $ Bytes.Lazy.fromStrict compact
      Jose.verifyJWTAt settings jwk_ now jwt
    case outcome of
      Left e -> check (buildError req) (Left e)
      Right Nothing -> pure Indefinite
      Right (Just (v :: JwtPayload payload)) -> check (buildError req) (Right v)
   where
    -- Lift a pure 'Maybe' into the 'MaybeT' layer of the JOSE computation.
    hoistMaybe = MaybeT . pure

    -- Credentials of a @Bearer@ Authorization value, or 'Nothing' for any
    -- other scheme (or a bare scheme with nothing after it).
    bearerToken :: Bytes.ByteString -> Maybe Bytes.ByteString
    bearerToken hdr
      | Bytes.map asciiLower scheme == "bearer"
      , not (Bytes.null rest) =
          Just (Bytes.dropWhile (== spaceByte) rest)
      | otherwise = Nothing
      where
        (scheme, rest) = Bytes.break (== spaceByte) hdr

    -- ASCII space (0x20), the separator between scheme and credentials.
    spaceByte = 0x20

    -- Lower-case ASCII letters A-Z; every other byte is left unchanged.
    asciiLower w
      | w >= 0x41 && w <= 0x5A = w + 0x20
      | otherwise = w