-- Copyright (C) 2014-2022  Fraser Tweedale
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TemplateHaskell #-}

{-|

JOSE error types and helpers.

-}
module Crypto.JOSE.Error
  (
  -- * Running JOSE computations
    runJOSE
  , unwrapJOSE
  , JOSE(..)

  -- * Base error type and class
  , Error(..)
  , AsError(..)

  -- * JOSE compact serialisation errors
  , InvalidNumberOfParts(..), expectedParts, actualParts
  , CompactTextError(..)
  , CompactDecodeError(..)
  , _CompactInvalidNumberOfParts
  , _CompactInvalidText

  ) where

import Data.Semigroup ((<>))
import Numeric.Natural

import Control.Monad.Except
import Control.Monad.Trans
import qualified Crypto.PubKey.RSA as RSA
import Crypto.Error (CryptoError)
import Crypto.Random (MonadRandom(..))
import Control.Lens (Getter, to)
import Control.Lens.TH (makeClassyPrisms, makePrisms)
import qualified Data.Text as T
import qualified Data.Text.Encoding.Error as T


-- | The wrong number of parts were found when decoding a
-- compact JOSE object.
--
-- The first field is the number of parts that was expected and the
-- second is the number actually found; see 'expectedParts' and
-- 'actualParts'.
--
data InvalidNumberOfParts =
  InvalidNumberOfParts Natural Natural -- ^ expected vs actual parts
  deriving (Eq)

-- | Renders as @Expected \<n\> parts; got \<m\>@, where the first
-- field is the expected count and the second the actual count.
instance Show InvalidNumberOfParts where
  show (InvalidNumberOfParts n m) =
    "Expected " <> show n <> " parts; got " <> show m

-- | Get the expected or actual number of parts.
--
-- 'expectedParts' views the first field of 'InvalidNumberOfParts'
-- and 'actualParts' views the second.
expectedParts, actualParts :: Getter InvalidNumberOfParts Natural
expectedParts = to $ \(InvalidNumberOfParts n _) -> n
actualParts   = to $ \(InvalidNumberOfParts _ n) -> n


-- | Bad UTF-8 data in a compact object, at the specified index
data CompactTextError = CompactTextError
  Natural             -- ^ index of the part containing the bad data
  T.UnicodeException  -- ^ the underlying text decoding error
  deriving (Eq)

-- | Renders as @Invalid text at part \<n\>: \<exception\>@, using the
-- 'Show' output of the wrapped 'T.UnicodeException'.
instance Show CompactTextError where
  show (CompactTextError n s) =
    "Invalid text at part " <> show n <> ": " <> show s


-- | An error when decoding a JOSE compact object.
-- JSON decoding errors that occur during compact object processing
-- throw 'JSONDecodeError'.
--
data CompactDecodeError
  = CompactInvalidNumberOfParts InvalidNumberOfParts
  -- ^ the object did not have the expected number of parts
  | CompactInvalidText CompactTextError
  -- ^ a part contained invalid UTF-8 data
  deriving (Eq)
-- Generates the '_CompactInvalidNumberOfParts' and
-- '_CompactInvalidText' prisms named in the export list.
makePrisms ''CompactDecodeError

-- | Prefixes the 'Show' output of the wrapped error with a short
-- description of which kind of compact decoding failure occurred.
instance Show CompactDecodeError where
  show (CompactInvalidNumberOfParts e) = "Invalid number of parts: " <> show e
  show (CompactInvalidText e) = "Invalid text: " <> show e



-- | All the errors that can occur.
--
data Error
  = AlgorithmNotImplemented   -- ^ A requested algorithm is not implemented
  | AlgorithmMismatch String  -- ^ A requested algorithm cannot be used
  | KeyMismatch T.Text        -- ^ Wrong type of key was given
  | KeySizeTooSmall           -- ^ Key size is too small
  | OtherPrimesNotSupported   -- ^ RSA private key with >2 primes not supported
  | RSAError RSA.Error        -- ^ RSA encryption, decryption or signing error
  | CryptoError CryptoError   -- ^ Various cryptonite library error cases
  | CompactDecodeError CompactDecodeError
  -- ^ Compact serialisation could not be decoded (wrong number of
  --   parts, or invalid text in a part)
  | JSONDecodeError String    -- ^ JSON (Aeson) decoding error
  | NoUsableKeys              -- ^ No usable keys were found in the key store
  | JWSCritUnprotected
  | JWSNoValidSignatures
  -- ^ 'AnyValidated' policy active, and no valid signature encountered
  | JWSInvalidSignature
  -- ^ 'AllValidated' policy active, and invalid signature encountered
  | JWSNoSignatures
  -- ^ 'AllValidated' policy active, and there were no signatures on object
  --   that matched the allowed algorithms
  deriving (Eq, Show)
-- Generates the 'AsError' class and its prisms, exported as 'AsError(..)'.
makeClassyPrisms ''Error


-- | A computation in the base monad @m@ that can fail with an error
-- of type @e@.  This is a newtype wrapper around @'ExceptT' e m a@;
-- use 'runJOSE' or 'unwrapJOSE' to get at the wrapped computation.
newtype JOSE e m a = JOSE (ExceptT e m a)

-- | Run the 'JOSE' computation.  Result is an @Either e a@
-- where @e@ is the error type (typically 'Error' or 'Crypto.JWT.JWTError')
--
-- This removes the 'JOSE' wrapper and runs the underlying 'ExceptT'
-- with 'runExceptT'.
runJOSE :: JOSE e m a -> m (Either e a)
runJOSE (JOSE m) = runExceptT m

-- | Get the inner 'ExceptT' value of the 'JOSE' computation.
-- Typically 'runJOSE' would be preferred, unless you specifically
-- need an 'ExceptT' value.
unwrapJOSE :: JOSE e m a -> ExceptT e m a
unwrapJOSE (JOSE m) = m


-- | Maps over the result by delegating to the wrapped 'ExceptT'.
instance (Functor m) => Functor (JOSE e m) where
  fmap f (JOSE ma) = JOSE (fmap f ma)

-- | 'pure' and '<*>' are those of the wrapped 'ExceptT', rewrapped
-- in 'JOSE'.
instance (Monad m) => Applicative (JOSE e m) where
  pure = JOSE . pure
  JOSE mf <*> JOSE ma = JOSE (mf <*> ma)

-- | Binds in the wrapped 'ExceptT'.  The continuation's 'JOSE' result
-- is unwrapped with 'unwrapJOSE' so that it can be sequenced there.
instance (Monad m) => Monad (JOSE e m) where
  JOSE ma >>= f = JOSE (ma >>= unwrapJOSE . f)

-- | Lifts a base-monad action into the wrapped 'ExceptT', then wraps
-- the result in 'JOSE'.
instance MonadTrans (JOSE e) where
  lift = JOSE . lift

-- | Throwing and catching are delegated to the wrapped 'ExceptT'.
instance (Monad m) => MonadError e (JOSE e m) where
  throwError = JOSE . throwError
  -- The handler returns a 'JOSE' value, so it is unwrapped before being
  -- handed to the underlying 'catchError'.
  catchError (JOSE m) handle = JOSE (catchError m (unwrapJOSE . handle))

-- | Lifts an 'IO' action via the wrapped 'ExceptT', then wraps the
-- result in 'JOSE'.
instance (MonadIO m) => MonadIO (JOSE e m) where
  liftIO = JOSE . liftIO

-- | Random bytes are obtained from the base monad's 'MonadRandom'
-- instance and lifted into 'JOSE'.
instance (MonadRandom m) => MonadRandom (JOSE e m) where
    getRandomBytes = lift . getRandomBytes