{-# LANGUAGE ConstraintKinds             #-}
{-# LANGUAGE MultiParamTypeClasses       #-}
{-# LANGUAGE FlexibleInstances           #-}
{-# LANGUAGE TypeFamilies                #-}
{-# LANGUAGE DataKinds                   #-}
{-# LANGUAGE FlexibleContexts            #-}
{-# LANGUAGE RecordWildCards             #-}
-- | An implementation of a simple MAC based on a cryptographic
-- hash. This construction is safe only for certain hashes, like
-- blake2, and therefore should not be used indiscriminately. In
-- particular, sha2 hashes should not be used in this mode as they are
-- prone to length-extension attacks.
--
-- If you want to use sha2 hashes for message authentication, you
-- should use the more complicated HMAC construction instead.
--
module Mac.Implementation
          ( Prim
          , name
          , description
          , Internals
          , BufferAlignment
          , BufferPtr
          , processBlocks
          , processLast
          , additionalBlocks
          , Key (..)
          ) where

import           Data.ByteString       as BS
import           Raaz.Core
import           Raaz.Core.Transfer.Unsafe
import           Raaz.Primitive.Keyed.Internal
import qualified Implementation        as Base
import qualified Utils                 as U
import qualified Buffer                as B

-- | The primitive implemented here: the base hash wrapped in 'Keyed'.
type Prim = Keyed Base.Prim

-- | Name of the implementation: the base implementation's name
-- followed by the suffix @-keyed-hash@.
name :: String
name = Base.name ++ "-keyed-hash"

-- | Description of the implementation, mentioning the name of the
-- base hash implementation it is built on.
description :: String
description = "Implementation of a MAC based on simple keyed hashing that makes use of "
              ++ Base.name
              ++ " implementation."

-- | Alignment of the buffers used by this implementation; the same as
-- that of the base hash implementation.
type BufferAlignment = Base.BufferAlignment
-- | Aligned pointer to blocks of the keyed primitive, using
-- 'BufferAlignment'.
type BufferPtr       = AlignedBlockPtr BufferAlignment Prim

-- | Reinterpret a block count of the base primitive as a block count
-- of the keyed primitive, going through its 'Enum' value.
toKeyedBlocks :: BlockCount Base.Prim -> BlockCount Prim
toKeyedBlocks = toEnum . fromEnum

-- | Reinterpret a block count of the keyed primitive as a block count
-- of the base primitive, going through its 'Enum' value. Inverse of
-- 'toKeyedBlocks'.
fromKeyedBlocks :: BlockCount Prim -> BlockCount Base.Prim
fromKeyedBlocks = toEnum . fromEnum

-- | The additional space required in the buffer for processing the
-- data: the base implementation's 'Base.additionalBlocks', expressed
-- as a block count of the keyed primitive.
additionalBlocks :: BlockCount Prim
additionalBlocks = toKeyedBlocks Base.additionalBlocks

-- | Truncate the key to at most @sizeOf@ of the base primitive, in
-- bytes. 'BS.take' leaves shorter keys unchanged.
trim :: Key (Keyed Base.Prim) -> BS.ByteString
trim (Key hKey) = BS.take sz hKey
  where sz :: Int
        sz = fromEnum $ sizeOf (Proxy :: Proxy Base.Prim)


-- | The internal memory used by the implementation.
data Internals = MACInternals
  { hashInternals :: Base.Internals
    -- ^ Internal state of the underlying hash implementation.
  , keyBuffer     :: B.Buffer 1
    -- ^ Buffer into which the (trimmed, zero-padded) key is written
    -- at initialisation, so that it can be processed later.
  , atStart       :: MemoryCell Bool
    -- ^ Flag to check whether the key has been processed or not.
    -- See the note on Delayed key processing.
  }

-- | Process the key held in the key buffer using the
-- buffer-processing helper 'U.processBuffer', updating the base hash
-- internals.
processKey :: Internals -> IO ()
processKey MACInternals{..} = U.processBuffer keyBuffer hashInternals


-- | Process the key held in the key buffer with the base
-- implementation's @processLast@, treating one full block of the key
-- buffer as the final data of the stream.
processKeyLast :: Internals -> IO ()
processKeyLast MACInternals{..} = Base.processLast bufPtr bufsz hashInternals
  where
    -- Pointer into the key buffer, as expected by the base implementation.
    bufPtr = B.unsafeGetBufferPointer keyBuffer
    -- Exactly one block of the base primitive, in bytes.
    bufsz :: BYTES Int
    bufsz = inBytes $ blocksOf 1 (Proxy :: Proxy Base.Prim)


-- | Memory for the MAC internals is allocated field by field with the
-- 'Applicative' interface of the allocator. The pointer exposed by
-- 'unsafeToPointer' is that of the base hash internals.
instance Memory Internals where
  memoryAlloc     = MACInternals <$> memoryAlloc <*> memoryAlloc <*> memoryAlloc
  unsafeToPointer = unsafeToPointer . hashInternals

-- * Delayed key processing ::
--
-- It might seem that the initialisation step is straightforward:
-- write the padded key to the buffer and then run processBlocks on
-- it. This works as long as the message to be authenticated is at
-- least 1 byte long.
--
-- For an empty message, however, the padded key block is the last
-- block, and hashes like blake2 pass a different finalisation flag
-- for the last block. At initialisation we cannot predict whether the
-- message we are about to see is empty or not. So we keep everything
-- ready (i.e. write the key into the key buffer) and set a flag that
-- says we are at the start of the message processing. The first call
-- to processBlocks or processLast performs the appropriate
-- initialisation and then proceeds from there.

-- | Initialise the MAC state from a key.
--
-- The base hash is initialised with 'hashInit' applied to the length
-- of the trimmed key. The trimmed key, padded with zero bytes via
-- 'padWrite', is written into the key buffer. The key is /not/
-- processed here. Instead the 'atStart' flag is set, so that the first
-- call to 'processBlocks' or 'processLast' can process it
-- appropriately (see the note on Delayed key processing).
instance Initialisable Internals (Key (Keyed Base.Prim)) where
  initialise hKey imem = do
      initialise hash0 $ hashInternals imem
      writeKeyIntoBuffer $ keyBuffer imem
      initialise True $ atStart imem
    where
      -- The key, truncated to at most the size of a base hash value.
      kbs = trim hKey

      -- Initial base hash value, derived from the key length.
      hash0 :: Base.Prim
      hash0 = hashInit $ Raaz.Core.length kbs

      -- Writes the key followed by 0-byte padding, using one block of
      -- the base primitive as the padding unit.
      keyWrite = padWrite 0 (blocksOf 1 proxyPrim) $ writeByteString kbs

      writeKeyIntoBuffer = unsafeTransfer keyWrite . B.unsafeGetBufferPointer

      proxyPrim = Proxy :: Proxy Base.Prim

-- | The MAC value is the base hash value extracted from the hash
-- internals, wrapped as a keyed value with 'unsafeToKeyed'.
instance Extractable Internals Prim where
  extract = fmap unsafeToKeyed . extract . hashInternals


-- | The function that processes bytes in multiples of the block size
-- of the primitive.
--
-- On the first call after initialisation (when 'atStart' is set), the
-- key held in the key buffer is processed first and the flag is
-- cleared. The data blocks are then passed to the base implementation.
processBlocks :: BufferPtr
              -> BlockCount Prim
              -> Internals
              -> IO ()
processBlocks aptr blks imem@MACInternals{..} = do
  start <- extract atStart
  when start $ do
    processKey imem
    initialise False atStart
  -- The pointer and block count are in terms of the keyed primitive;
  -- convert both to what the base implementation expects.
  Base.processBlocks (castPointer aptr) (fromKeyedBlocks blks) hashInternals

-- | Process the last bytes of the stream.
--
-- If the key has not been processed yet and there are no remaining
-- bytes (i.e. the message is empty), the key block itself is processed
-- as the last block. Otherwise any pending key block is processed
-- first, and the remaining bytes are passed to the base
-- implementation's @processLast@.
processLast :: BufferPtr
            -> BYTES Int
            -> Internals
            -> IO ()
processLast aptr sz imem@MACInternals{..} = do
  start <- extract atStart
  if start && sz == 0
    then processKeyLast imem
    else do
      when start $ processKey imem
      Base.processLast (castPointer aptr) sz hashInternals