module Crypto.Cipher.ChaChaPoly1305
( State
, Nonce
, nonce12
, nonce8
, incrementNonce
, initialize
, appendAAD
, finalizeAAD
, encrypt
, decrypt
, finalize
) where
import Control.Monad (when)
import Crypto.Internal.ByteArray (ByteArrayAccess, ByteArray, Bytes, ScrubbedBytes)
import qualified Crypto.Internal.ByteArray as B
import Crypto.Internal.Imports
import Crypto.Error
import qualified Crypto.Cipher.ChaCha as ChaCha
import qualified Crypto.MAC.Poly1305 as Poly1305
import Data.Memory.Endian
import qualified Data.ByteArray.Pack as P
import Foreign.Ptr
import Foreign.Storable
-- | Running state of an incremental ChaCha20-Poly1305 computation.
data State = State
    !ChaCha.State    -- ^ keystream state used by 'encrypt' / 'decrypt'
    !Poly1305.State  -- ^ authenticator state fed with AAD and ciphertext
    !Word64          -- ^ total number of AAD bytes passed to 'appendAAD'
    !Word64          -- ^ total number of bytes passed to 'encrypt' / 'decrypt'
-- | A validated 12-byte nonce.
--
-- 'Nonce8' holds a 4-byte constant followed by an 8-byte IV (see 'nonce8');
-- 'Nonce12' holds a plain 12-byte nonce (see 'nonce12').
data Nonce = Nonce8 Bytes | Nonce12 Bytes

-- | Raw bytes of a nonce, whichever constructor it was built with.
nonceBytes :: Nonce -> Bytes
nonceBytes nonce = case nonce of
    Nonce8 bytes  -> bytes
    Nonce12 bytes -> bytes

instance ByteArrayAccess Nonce where
    length        = B.length . nonceBytes
    withByteArray = B.withByteArray . nonceBytes
-- | Zero bytes that extend a stream of @n@ bytes up to the next multiple
-- of 16. The result is empty when @n@ is already a multiple of 16, and
-- otherwise has length @16 - (n mod 16)@.
pad16 :: Word64 -> Bytes
pad16 n
    | modLen == 0 = B.empty
    | otherwise   = B.replicate (16 - modLen) 0
  where
    modLen = fromIntegral (n `mod` 16)
-- | Build a nonce from exactly 12 bytes.
--
-- Fails with 'CryptoError_IvSizeInvalid' for any other length.
nonce12 :: ByteArrayAccess iv => iv -> CryptoFailable Nonce
nonce12 iv =
    if B.length iv == 12
        then CryptoPassed (Nonce12 (B.convert iv))
        else CryptoFailed CryptoError_IvSizeInvalid
-- | Build a nonce from a 4-byte constant and an 8-byte IV. The stored nonce
-- is the constant followed by the IV.
--
-- Fails with 'CryptoError_IvSizeInvalid' if either part has the wrong size.
nonce8 :: ByteArrayAccess ba
       => ba  -- ^ 4-byte constant
       -> ba  -- ^ 8-byte IV
       -> CryptoFailable Nonce
nonce8 constant iv
    | sizesOk   = CryptoPassed (Nonce8 (B.concat [constant, iv]))
    | otherwise = CryptoFailed CryptoError_IvSizeInvalid
  where
    sizesOk = B.length constant == 4 && B.length iv == 8
-- | Add one to a nonce, treating its counter bytes as a little-endian number.
--
-- For a 'Nonce8' the leading 4-byte constant is left unchanged and only the
-- trailing IV bytes are incremented. For a 'Nonce12' all 12 bytes are
-- incremented.
incrementNonce :: Nonce -> Nonce
incrementNonce nonce = case nonce of
    Nonce8 bytes  -> Nonce8 (incrementNonce' bytes 4)
    Nonce12 bytes -> Nonce12 (incrementNonce' bytes 0)
-- | Return a copy of @b@ in which the bytes from index @offset@ to the end,
-- read as a little-endian unsigned counter, have been incremented by one.
--
-- Bytes before @offset@ are never modified. When the counter is all 0xff it
-- wraps around to all zeros; no byte outside the buffer is read or written.
incrementNonce' :: Bytes -> Int -> Bytes
incrementNonce' b offset = B.copyAndFreeze b bump
  where
    len = B.length b

    bump :: Ptr Word8 -> IO ()
    bump base = carry (base `plusPtr` offset) (base `plusPtr` len)

    -- Increment the byte at @p@. If it wrapped to zero, carry into the next
    -- (more significant) byte. @end@ is one past the last byte; stopping
    -- there keeps an overflowing counter from running off the buffer.
    carry :: Ptr Word8 -> Ptr Word8 -> IO ()
    carry p end
        | p >= end  = return ()
        | otherwise = do
            r <- (+ 1) <$> peek p
            poke p r
            when (r == 0) $ carry (p `plusPtr` 1) end
-- | Start a ChaCha20-Poly1305 computation from a key and a nonce.
--
-- Fails with 'CryptoError_KeySizeInvalid' unless the key is 32 bytes.
initialize :: ByteArrayAccess key
           => key -> Nonce -> CryptoFailable State
initialize key nonce = initialize' key $ case nonce of
    Nonce8 bytes  -> bytes
    Nonce12 bytes -> bytes
-- | Set up the cipher and authenticator from a key and raw nonce bytes.
--
-- A 20-round ChaCha stream is created and its first 64 bytes are
-- consumed. The first 32 of those bytes key Poly1305. The cipher state left
-- after those 64 bytes is the one used for encryption. Both byte counters
-- start at zero.
initialize' :: ByteArrayAccess key
            => key -> Bytes -> CryptoFailable State
initialize' key nonce
    | B.length key == 32 = CryptoPassed (State cipherState macState 0 0)
    | otherwise          = CryptoFailed CryptoError_KeySizeInvalid
  where
    (keyBlock, cipherState) = ChaCha.generate (ChaCha.initialize 20 key nonce) 64
    macKey = B.take 32 keyBlock :: ScrubbedBytes
    -- macKey is always 32 bytes here, so Poly1305 initialisation cannot fail.
    macState = throwCryptoError (Poly1305.initialize macKey)
-- | Authenticate (but do not encrypt) a chunk of additional data.
--
-- The chunk is fed to Poly1305 and its length is added to the AAD counter.
appendAAD :: ByteArrayAccess ba => ba -> State -> State
appendAAD ba (State cipherSt macSt aadLen plainLen) =
    let macSt'  = Poly1305.update macSt ba
        aadLen' = aadLen + fromIntegral (B.length ba)
    in State cipherSt macSt' aadLen' plainLen
-- | Close the additional-data section by feeding Poly1305 enough zero bytes
-- to bring the AAD length to a multiple of 16.
--
-- NOTE(review): presumably meant to be called once, after the last
-- 'appendAAD' and before the first 'encrypt' / 'decrypt'. This is not
-- enforced here.
finalizeAAD :: State -> State
finalizeAAD (State cipherSt macSt aadLen plainLen) =
    State cipherSt (Poly1305.update macSt (pad16 aadLen)) aadLen plainLen
-- | Encrypt a chunk of plaintext.
--
-- The chunk is combined with the ChaCha keystream. The resulting
-- ciphertext, not the plaintext, is fed to Poly1305, and the message length
-- counter grows by the chunk size.
encrypt :: ByteArray ba => ba -> State -> (ba, State)
encrypt input (State cipherSt macSt aadLen plainLen) =
    let (ciphertext, cipherSt') = ChaCha.combine cipherSt input
        macSt'    = Poly1305.update macSt ciphertext
        plainLen' = plainLen + fromIntegral (B.length input)
    in (ciphertext, State cipherSt' macSt' aadLen plainLen')
-- | Decrypt a chunk of ciphertext.
--
-- The incoming ciphertext is fed to Poly1305 (matching what 'encrypt'
-- authenticates). It is then combined with the ChaCha keystream to recover
-- the plaintext, and the message length counter grows by the chunk size.
decrypt :: ByteArray ba => ba -> State -> (ba, State)
decrypt input (State cipherSt macSt aadLen plainLen) =
    let (plaintext, cipherSt') = ChaCha.combine cipherSt input
        macSt'    = Poly1305.update macSt input
        plainLen' = plainLen + fromIntegral (B.length input)
    in (plaintext, State cipherSt' macSt' aadLen plainLen')
-- | Produce the authentication tag.
--
-- Poly1305 is first fed zero padding that brings the message length to a
-- multiple of 16. It then receives a 16-byte block holding the AAD length
-- and the message length, each as a little-endian 64-bit word.
finalize :: State -> Poly1305.Auth
finalize (State _ macSt aadLen plainLen) =
    Poly1305.finalize (Poly1305.updates macSt [pad16 plainLen, lengthBlock])
  where
    -- Exactly 16 bytes are packed into a 16-byte buffer, so the Left branch
    -- is unreachable.
    lengthBlock = case P.fill 16 packLengths of
        Left _   -> error "finalize: internal error"
        Right bs -> bs
    packLengths = P.putStorable (toLE aadLen) >> P.putStorable (toLE plainLen)