{-# LANGUAGE OverloadedStrings #-}
-- | This module contains all the web framework independent code for parsing and verifying
-- JSON Web Tokens.
module JwtAuth where

import           Control.Monad         ((<=<))
import qualified Data.Aeson            as Aeson
import qualified Data.Aeson.Types      as Aeson
import           Data.Bifunctor        (first)
import qualified Data.ByteString       as SBS
import qualified Data.Map.Strict       as Map
import qualified Data.Text.Encoding    as Text
import           Data.Time.Clock.POSIX (POSIXTime)
import           Web.JWT               (JWT, UnverifiedJWT, VerifiedJWT)
import qualified Web.JWT               as JWT

import           AccessControl

-- * Token verification

-- | Reasons why a token can be rejected while it is being verified.
data VerificationError
  = TokenUsedTooEarly     -- ^ Rejected by the not-before (@nbf@) check.
  | TokenExpired          -- ^ Rejected by the expiry (@exp@) check.
  | TokenInvalid          -- ^ The bytes could not be decoded as a JWT.
  | TokenNotFound         -- ^ Not produced in this module; presumably used by callers when no token was supplied (TODO confirm).
  | TokenSignatureInvalid -- ^ The signature did not verify against the secret.
  deriving (Eq, Show)

-- | Decode a token and check that it is valid at the given time for the
-- given secret.
--
-- The steps run in this order: decoding, signature verification, expiry
-- check, not-before check. The first step that fails determines the
-- 'VerificationError' that is returned.
verifyToken :: POSIXTime -> JWT.Signer -> SBS.ByteString -> Either VerificationError (JWT VerifiedJWT)
verifyToken now secret = checkTimeClaims <=< verifySignature secret <=< decodeToken
  where
    -- Time-based checks, applied only once the signature has been accepted.
    checkTimeClaims = verifyNotBefore now <=< verifyExpiry now

-- | Verify that the token is not used before its not-before (@nbf@) time.
--
-- A token without an @nbf@ claim is accepted. Otherwise the token is
-- accepted when @now@ is after or equal to the @nbf@ time, as required by
-- RFC 7519 section 4.1.5; only a strictly earlier @now@ yields
-- 'TokenUsedTooEarly'.
verifyNotBefore :: POSIXTime -> JWT VerifiedJWT -> Either VerificationError (JWT VerifiedJWT)
verifyNotBefore now token =
  case JWT.nbf (JWT.claims token) of
    Just notBefore
      -- Strict comparison: a token is already usable at exactly its nbf instant.
      | now < JWT.secondsSinceEpoch notBefore -> Left TokenUsedTooEarly
    _ -> Right token

-- | Verify that the token is not used after it has expired.
--
-- A token without an @exp@ claim is accepted. Otherwise @now@ must be
-- strictly before the expiration time, as required by RFC 7519 section
-- 4.1.4; a @now@ equal to or later than it yields 'TokenExpired'.
verifyExpiry :: POSIXTime -> JWT VerifiedJWT -> Either VerificationError (JWT VerifiedJWT)
verifyExpiry now token =
  case JWT.exp (JWT.claims token) of
    Just expiry
      -- Inclusive comparison: the token is no longer valid at exactly its exp instant.
      | now >= JWT.secondsSinceEpoch expiry -> Left TokenExpired
    _ -> Right token

-- | Check the token's signature against the given secret.
--
-- Returns the verified token when 'JWT.verify' succeeds and
-- 'TokenSignatureInvalid' when it returns 'Nothing'.
verifySignature :: JWT.Signer -> JWT UnverifiedJWT -> Either VerificationError (JWT VerifiedJWT)
verifySignature secret token =
  maybe (Left TokenSignatureInvalid) Right (JWT.verify secret token)

-- | Decode the raw bytes of a token into an unverified JWT.
--
-- The bytes are first decoded as UTF-8 and then parsed with 'JWT.decode'.
-- Either step failing yields 'TokenInvalid'. The total @decodeUtf8'@ is used
-- instead of the partial @decodeUtf8@ so that malformed input cannot raise
-- an exception from this pure function.
decodeToken :: SBS.ByteString -> Either VerificationError (JWT UnverifiedJWT)
decodeToken bytes =
  case Text.decodeUtf8' bytes of
    Left _     -> Left TokenInvalid
    Right text -> maybe (Left TokenInvalid) Right (JWT.decode text)

-- * Claim parsing

-- | Failures that can occur while extracting the icepeak claim from a token.
data TokenError
  = VerificationError VerificationError -- ^ The JWT could not be decoded or verified.
  | ClaimError String                   -- ^ The icepeak claim is missing or does not fit the schema.
  deriving (Eq, Show)

-- | Verify the token with 'verifyToken' and extract the icepeak claim from it.
--
-- Verification failures are wrapped in 'VerificationError'; a missing or
-- unparsable claim is reported as 'ClaimError'.
extractClaim :: POSIXTime -> JWT.Signer -> SBS.ByteString -> Either TokenError IcepeakClaim
extractClaim now secret tokenBytes =
  first VerificationError (verifyToken now secret tokenBytes)
    >>= first ClaimError . getIcepeakClaim

-- | Extract the icepeak claim from the token without verifying it.
--
-- The token is only decoded with 'decodeToken'; neither the signature nor
-- the time claims are checked. Decoding failures are wrapped in
-- 'VerificationError'; a missing or unparsable claim is reported as
-- 'ClaimError'.
extractClaimUnverified :: SBS.ByteString -> Either TokenError IcepeakClaim
extractClaimUnverified tokenBytes =
  first VerificationError (decodeToken tokenBytes)
    >>= first ClaimError . getIcepeakClaim

-- | Look up the @"icepeak"@ entry in the token's unregistered claims and
-- parse it as an 'IcepeakClaim'.
--
-- Returns @Left@ with a message when the entry is absent or when the JSON
-- value is rejected by the 'Aeson.parseJSON' parser.
getIcepeakClaim :: JWT r -> Either String IcepeakClaim
getIcepeakClaim token =
  case Map.lookup "icepeak" unregistered of
    Nothing        -> Left "Icepeak claim missing."
    Just claimJson -> Aeson.parseEither Aeson.parseJSON claimJson
  where
    JWT.ClaimsMap unregistered = JWT.unregisteredClaims (JWT.claims token)

-- * Token generation

-- | Add the icepeak claim to a set of JWT claims.
--
-- The claim is serialised with 'Aeson.toJSON' and stored under the
-- @"icepeak"@ key of the unregistered claims. The new entry is combined on
-- the left of the existing claims with '<>'; presumably this makes it take
-- precedence over an existing @"icepeak"@ entry (depends on the @ClaimsMap@
-- 'Semigroup' instance -- TODO confirm).
addIcepeakClaim :: IcepeakClaim -> JWT.JWTClaimsSet -> JWT.JWTClaimsSet
addIcepeakClaim claim claims =
  claims { JWT.unregisteredClaims = icepeakEntry <> JWT.unregisteredClaims claims }
  where
    icepeakEntry = JWT.ClaimsMap (Map.singleton "icepeak" (Aeson.toJSON claim))