module Web.Minion.Auth.Jwt where
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Maybe (MaybeT (..))
import Crypto.JOSE qualified as Jose
import Crypto.JWT (JWTError)
import Crypto.JWT qualified as Jose
import Data.Aeson (FromJSON (..))
import Data.ByteString qualified as Bytes
import Data.ByteString.Lazy qualified as Bytes.Lazy
import Data.Function ((&))
import Data.Functor ((<&>))
import Data.Time qualified as Time
import Network.HTTP.Types.Header qualified as Http
import Network.Wai qualified as Wai
import Web.Minion
-- | Configuration for 'Bearer' JWT authentication.
--
-- The three monadic actions ('getNow', 'jwk', 'validationSettings') are run
-- by the 'Bearer' 'IsAuth' instance on every request, before the
-- @Authorization@ header is examined.
data JwtAuthSettings m payload a = JwtAuthSettings
  { -- | Time at which the token is validated (passed to 'Jose.verifyJWTAt').
    getNow :: m Time.UTCTime
  , -- | Key used to verify the token signature (passed to 'Jose.verifyJWTAt').
    jwk :: m Jose.JWK
  , -- | Validation settings handed to 'Jose.verifyJWTAt'.
    validationSettings :: m Jose.JWTValidationSettings
  , -- | Turn the verification outcome into an 'AuthResult'. Receives the
    -- request's 'MakeError' and either the JOSE/JWT error or the decoded
    -- payload. It is not consulted when the request carries no bearer
    -- credentials; that case yields 'Indefinite'.
    check :: MakeError -> Either JWTError (JwtPayload payload) -> m (AuthResult a)
  }
-- | Build 'JwtAuthSettings' with sensible defaults.
--
-- * The clock is 'Time.getCurrentTime', lifted with 'liftIO'.
-- * The validation settings are the constant
--   @'Jose.defaultJWTValidationSettings' audiencePredicate@.
-- * The key action and the result handler are stored unchanged as 'jwk'
--   and 'check'.
defaultJwtAuthSettings ::
  (MonadIO m) =>
  m Jose.JWK ->
  (Jose.StringOrURI -> Bool) ->
  (MakeError -> Either JWTError (JwtPayload payload) -> m (AuthResult a)) ->
  JwtAuthSettings m payload a
defaultJwtAuthSettings loadKey audiencePredicate decide =
  JwtAuthSettings
    { getNow = liftIO Time.getCurrentTime
    , jwk = loadKey
    , validationSettings = pure joseDefaults
    , check = decide
    }
  where
    joseDefaults = Jose.defaultJWTValidationSettings audiencePredicate
-- | Claims of a JWT split into the registered claim set and an
-- application-defined payload. Both halves are decoded from the same JSON
-- object (see the 'FromJSON' instance).
data JwtPayload a = JwtPayload
  { -- | Registered JWT claims, exposed to jose via 'Jose.HasClaimsSet'.
    claims :: Jose.ClaimsSet
  , -- | Application-specific data decoded from the claims object.
    payload :: a
  }
-- | Lens onto the 'claims' field; the 'payload' field is carried through
-- unchanged. This is what lets 'Jose.verifyJWTAt' check the registered
-- claims of a 'JwtPayload'.
instance Jose.HasClaimsSet (JwtPayload a) where
  claimsSet focus (JwtPayload registered custom) =
    focus registered <&> (`JwtPayload` custom)
-- | Parse both halves from the same JSON value: the registered claims with
-- the 'Jose.ClaimsSet' parser, then the custom payload with @a@'s parser.
instance (FromJSON a) => FromJSON (JwtPayload a) where
  parseJSON input = do
    registered <- parseJSON input
    custom <- parseJSON input
    pure (JwtPayload registered custom)
-- | Empty, constructor-less type used only as a type-level tag that selects
-- JWT bearer-token authentication. @payload@ is the application claims type
-- that ends up inside 'JwtPayload'.
data Bearer payload
-- | Bearer-token authentication backed by a signed JWT in the
-- @Authorization@ header.
--
-- * No @Authorization@ header, or a scheme other than @Bearer@: the result
--   is 'Indefinite' and 'check' is not called.
-- * The scheme name is matched case-insensitively (RFC 9110 §11.1).
--   One or more spaces may separate it from the token (RFC 6750 §2.1).
-- * A compact-decoding or verification failure: 'check' receives
--   @Left err@.
-- * Successful verification with 'Jose.verifyJWTAt': 'check' receives
--   @Right payload@.
--
-- The 'jwk', 'getNow' and 'validationSettings' actions run on every
-- request, before the header is inspected.
instance (MonadIO m, FromJSON payload) => IsAuth (Bearer payload) m a where
  type Settings (Bearer payload) m a = JwtAuthSettings m payload a

  toAuth JwtAuthSettings{..} buildError req = do
    key <- jwk
    now <- getNow
    settings <- validationSettings
    -- Outer Left: JOSE/JWT failure; inner Nothing: no bearer credentials.
    outcome <- Jose.runJOSE $ runMaybeT do
      authHeader <- Wai.requestHeaders req & lookup Http.hAuthorization & hoistMaybe
      compact <- hoistMaybe $ bearerToken authHeader
      jwt <- Jose.decodeCompact $ Bytes.Lazy.fromStrict compact
      Jose.verifyJWTAt settings key now jwt
    case outcome of
      Left e -> check (buildError req) (Left e)
      Right Nothing -> pure Indefinite
      Right (Just (v :: JwtPayload payload)) -> check (buildError req) (Right v)
   where
    hoistMaybe = MaybeT . pure

    -- Return the credentials part of an Authorization value whose scheme is
    -- "Bearer" (any ASCII case). At least one space must follow the scheme;
    -- all separating spaces are skipped. Any other scheme, or a value with
    -- no space at all, yields Nothing.
    bearerToken :: Bytes.ByteString -> Maybe Bytes.ByteString
    bearerToken header
      | Bytes.null rest = Nothing
      | Bytes.map asciiLower scheme /= "bearer" = Nothing
      | otherwise = Just (Bytes.dropWhile (== space) rest)
      where
        (scheme, rest) = Bytes.break (== space) header
        space = 0x20
        -- ASCII 'A'..'Z' (0x41..0x5A) to lower case; other bytes unchanged.
        asciiLower w
          | w >= 0x41 && w <= 0x5A = w + 0x20
          | otherwise = w