{-# LANGUAGE OverloadedStrings #-}
module JwtAuth where
import Control.Monad ((<=<))
import qualified Data.Aeson as Aeson
import qualified Data.Aeson.Types as Aeson
import Data.Bifunctor (first)
import qualified Data.ByteString as SBS
import qualified Data.Map.Strict as Map
import qualified Data.Text.Encoding as Text
import Data.Time.Clock.POSIX (POSIXTime)
import Web.JWT (JWT, UnverifiedJWT, VerifiedJWT)
import qualified Web.JWT as JWT
import AccessControl
-- | Reasons for which verification of a token can fail.
data VerificationError
  = TokenUsedTooEarly
    -- ^ Returned by 'verifyNotBefore' when the token's @nbf@ claim rejects
    -- the current time.
  | TokenExpired
    -- ^ Returned by 'verifyExpiry' when the token's @exp@ claim rejects the
    -- current time.
  | TokenInvalid
    -- ^ Returned by 'decodeToken' when the bytes cannot be decoded as a JWT.
  | TokenNotFound
    -- ^ Not produced by any function in this module.
    -- NOTE(review): presumably used by callers when no token was supplied
    -- with a request; confirm against the call sites.
  | TokenSignatureInvalid
    -- ^ Returned by 'verifySignature' when the signature does not match the
    -- secret.
  deriving (Eq, Show)
-- | Check that a token is valid at the given time for the given secret.
--
-- The raw bytes are passed through, in order: 'decodeToken',
-- 'verifySignature', 'verifyExpiry' and 'verifyNotBefore'. The first step
-- that fails determines the 'VerificationError' that is returned.
verifyToken :: POSIXTime -> JWT.Signer -> SBS.ByteString -> Either VerificationError (JWT VerifiedJWT)
verifyToken now secret =
  checkTimes <=< verifySignature secret <=< decodeToken
  where
    -- Time-based checks; expiry is checked before not-before.
    checkTimes = verifyNotBefore now <=< verifyExpiry now
-- | Reject a token that is presented before its @nbf@ (not before) claim.
--
-- A token without an @nbf@ claim is accepted unchanged. RFC 7519 section
-- 4.1.5 requires the current time to be after /or equal to/ the @nbf@ time,
-- so a token used exactly at its @nbf@ instant is accepted; only a strictly
-- earlier @now@ yields 'TokenUsedTooEarly'.
verifyNotBefore :: POSIXTime -> JWT VerifiedJWT -> Either VerificationError (JWT VerifiedJWT)
verifyNotBefore now token =
  case JWT.nbf (JWT.claims token) of
    Just notBefore
      | now < JWT.secondsSinceEpoch notBefore -> Left TokenUsedTooEarly
    -- No @nbf@ claim, or the not-before instant has been reached.
    _ -> Right token
-- | Reject a token whose @exp@ (expiration time) claim has been reached.
--
-- A token without an @exp@ claim is accepted unchanged. RFC 7519 section
-- 4.1.4 requires the current time to be strictly /before/ the expiration
-- time, so a token used exactly at its @exp@ instant is already expired and
-- yields 'TokenExpired'.
verifyExpiry :: POSIXTime -> JWT VerifiedJWT -> Either VerificationError (JWT VerifiedJWT)
verifyExpiry now token =
  case JWT.exp (JWT.claims token) of
    Just expiry
      | now >= JWT.secondsSinceEpoch expiry -> Left TokenExpired
    -- No @exp@ claim, or the expiration instant has not been reached yet.
    _ -> Right token
-- | Check the signature of a decoded token against the given secret.
--
-- Returns the verified token when 'JWT.verify' succeeds, and
-- 'TokenSignatureInvalid' when it returns 'Nothing'.
verifySignature :: JWT.Signer -> JWT UnverifiedJWT -> Either VerificationError (JWT VerifiedJWT)
verifySignature secret unverified =
  maybe (Left TokenSignatureInvalid) Right (JWT.verify secret unverified)
-- | Decode raw bytes into an unverified JWT.
--
-- The bytes are first decoded as UTF-8 and then parsed with 'JWT.decode'.
-- Both failure modes yield 'TokenInvalid'. The total 'Text.decodeUtf8''
-- is used on purpose: the input is untrusted, and the partial
-- 'Text.decodeUtf8' would throw an exception from pure code on malformed
-- UTF-8 instead of producing a 'VerificationError'.
decodeToken :: SBS.ByteString -> Either VerificationError (JWT UnverifiedJWT)
decodeToken bytes =
  case Text.decodeUtf8' bytes of
    Left _ -> Left TokenInvalid
    Right text -> maybe (Left TokenInvalid) Right (JWT.decode text)
-- | Failure modes of extracting an Icepeak claim from a token.
data TokenError
  = VerificationError VerificationError
    -- ^ The token itself was rejected; wraps the underlying
    -- 'VerificationError'.
  | ClaimError String
    -- ^ The token was accepted but the Icepeak claim was missing or could
    -- not be parsed; carries the message produced by 'getIcepeakClaim'.
  deriving (Eq, Show)
-- | Verify the token and extract the Icepeak claim from it.
--
-- The token is checked with 'verifyToken' for the given time and secret;
-- a failure there is wrapped in the 'VerificationError' constructor of
-- 'TokenError'. The claim is then read with 'getIcepeakClaim'; a failure
-- there is wrapped in 'ClaimError'.
extractClaim :: POSIXTime -> JWT.Signer -> SBS.ByteString -> Either TokenError IcepeakClaim
extractClaim now secret tokenBytes =
  first VerificationError (verifyToken now secret tokenBytes)
    >>= first ClaimError . getIcepeakClaim
-- | Extract the Icepeak claim from the token without verifying it.
--
-- Only 'decodeToken' and 'getIcepeakClaim' are run: no signature, expiry
-- or not-before check is performed here. A decoding failure is wrapped in
-- the 'VerificationError' constructor of 'TokenError'; a missing or
-- unparsable claim in 'ClaimError'.
extractClaimUnverified :: SBS.ByteString -> Either TokenError IcepeakClaim
extractClaimUnverified tokenBytes =
  first VerificationError (decodeToken tokenBytes)
    >>= first ClaimError . getIcepeakClaim
-- | Read the Icepeak claim from a token's unregistered claims.
--
-- The claim is looked up under the key @"icepeak"@. If the key is absent
-- the result is @Left "Icepeak claim missing."@; otherwise the JSON value
-- is parsed with the 'Aeson.FromJSON' instance of 'IcepeakClaim', and a
-- parse failure is returned as the parser's error message.
getIcepeakClaim :: JWT r -> Either String IcepeakClaim
getIcepeakClaim token =
  case Map.lookup "icepeak" unregistered of
    Nothing -> Left "Icepeak claim missing."
    Just claimJson -> Aeson.parseEither Aeson.parseJSON claimJson
  where
    (JWT.ClaimsMap unregistered) = JWT.unregisteredClaims (JWT.claims token)
-- | Add the Icepeak claim to a set of JWT claims.
--
-- The claim is serialised with 'Aeson.toJSON' and stored under the key
-- @"icepeak"@ in the unregistered claims. The new entry is combined with
-- the existing unregistered claims using '<>', with the new entry on the
-- left.
-- NOTE(review): whether this replaces an already present @"icepeak"@ entry
-- depends on the 'Semigroup' instance of 'JWT.ClaimsMap' — confirm.
addIcepeakClaim :: IcepeakClaim -> JWT.JWTClaimsSet -> JWT.JWTClaimsSet
addIcepeakClaim claim claims =
  claims { JWT.unregisteredClaims = icepeakEntry <> existing }
  where
    existing = JWT.unregisteredClaims claims
    icepeakEntry = JWT.ClaimsMap (Map.singleton "icepeak" (Aeson.toJSON claim))