{-# LANGUAGE BangPatterns #-}
module Crypto.Hash.Keccak
(
keccak224
, keccak256
, keccak384
, keccak512
, sha3_512
, sha3_384
, sha3_256
, sha3_224
, shake128
, shake256
, keccakHash
, sha3Hash
, paddingKeccak
, paddingSha3
, absorb
, squeeze
) where
import Data.Bits
import qualified Data.ByteString as BS
import qualified Data.ByteString.Builder as BS
import qualified Data.ByteString.Lazy as LBS
import Data.Foldable
import Data.Monoid
import Data.Vector.Unboxed ((!))
import qualified Data.Vector.Unboxed as V
import qualified Data.Vector.Unboxed.Mutable as MV
import Data.Word
import Prelude hiding (pi)
-- | Number of rounds of Keccak-f[1600]: @12 + 2 * l@ where the lane width
-- is @2 ^ l@ bits (@l = 6@ for 64-bit lanes).
rounds :: Int
rounds = 12 + 2 * 6

-- | Lanes in the state: a 5 x 5 grid.
numLanes :: Int
numLanes = 5 * 5

-- | Bits per lane: the 1600-bit state split evenly over 'numLanes' lanes.
laneWidth :: Int
laneWidth = 1600 `div` numLanes
-- | The all-zero state that absorption starts from ('numLanes' lanes).
emptyState :: V.Vector Word64
emptyState = V.generate numLanes (const 0)
-- | Round constants for Keccak-f[1600], one per round (24 entries,
-- matching 'rounds').
--
-- 'iota' XORs entry @r@ into lane 0 during round @r@; the values are the
-- standard Keccak-f[1600] round constants and must not be edited.
roundConstants :: V.Vector Word64
roundConstants = V.fromList [ 0x0000000000000001, 0x0000000000008082, 0x800000000000808A
                            , 0x8000000080008000, 0x000000000000808B, 0x0000000080000001
                            , 0x8000000080008081, 0x8000000000008009, 0x000000000000008A
                            , 0x0000000000000088, 0x0000000080008009, 0x000000008000000A
                            , 0x000000008000808B, 0x800000000000008B, 0x8000000000008089
                            , 0x8000000000008003, 0x8000000000008002, 0x8000000000000080
                            , 0x000000000000800A, 0x800000008000000A, 0x8000000080008081
                            , 0x8000000000008080, 0x0000000080000001, 0x8000000080008008 ]
-- | Per-lane left-rotation offsets used by 'rho'.
--
-- Indexed in this module's state layout: the entry at @5 * x + y@ is the
-- offset for lane (x, y). Each row of five below therefore holds one value
-- of @x@ with @y = 0 .. 4@. 'rho' pairs this table element-wise with the
-- state via 'V.zipWith'.
rotationConstants :: V.Vector Int
rotationConstants = V.fromList [ 0, 36, 3, 41, 18
                               , 1, 44, 10, 45, 2
                               , 62, 6, 43, 15, 61
                               , 28, 55, 25, 21, 56
                               , 27, 20, 39, 8, 14 ]
-- | Source-lane table for the pi step, in this module's state layout
-- (lane (x, y) stored at index @5 * x + y@).
--
-- Entry @5 * x' + y'@ is the index of the lane that 'pi' moves to position
-- (x', y'). Since pi sends lane (x, y) to (y, 2x + 3y mod 5), position
-- (x', y') receives lane ((x' + 3y') mod 5, x'). The generated table is
--
-- > [ 0,15, 5,20,10,  6,21,11, 1,16, 12, 2,17, 7,22,
-- >  18, 8,23,13, 3, 24,14, 4,19, 9 ]
piConstants :: V.Vector Int
piConstants = V.generate numLanes sourceLane
  where
    sourceLane z =
      let (x', y') = z `divMod` 5
      in 5 * ((x' + 3 * y') `mod` 5) + x'
-- | Fixed-length sponge hash built from a padding rule and a rate in bits.
--
-- The message is padded to whole @rate@-bit blocks (the padding function is
-- given the rate in bytes) and absorbed. The digest is then squeezed out at
-- half the capacity: @(1600 - rate) / 2@ bits, i.e. @(1600 - rate) / 16@
-- bytes.
hashFunction :: (Int -> BS.ByteString -> V.Vector Word8) -> Int -> BS.ByteString -> BS.ByteString
hashFunction paddingFunction rate message =
  let padded = paddingFunction (rate `div` 8) message
      digestBytes = (1600 - rate) `div` 16
  in squeeze rate digestBytes (absorb rate padded)
-- | Keccak hash with the original Keccak padding ('paddingKeccak') and the
-- given rate in bits.
keccakHash :: Int -> BS.ByteString -> BS.ByteString
keccakHash rate = hashFunction paddingKeccak rate

-- | SHA-3 hash with SHA-3 padding ('paddingSha3') and the given rate in bits.
sha3Hash :: Int -> BS.ByteString -> BS.ByteString
sha3Hash rate = hashFunction paddingSha3 rate
-- | Keccak-512: rate = 1600 - 2 * 512 = 576 bits.
keccak512 :: BS.ByteString -> BS.ByteString
keccak512 = keccakHash (1600 - 2 * 512)

-- | Keccak-384: rate = 1600 - 2 * 384 = 832 bits.
keccak384 :: BS.ByteString -> BS.ByteString
keccak384 = keccakHash (1600 - 2 * 384)

-- | Keccak-256: rate = 1600 - 2 * 256 = 1088 bits.
keccak256 :: BS.ByteString -> BS.ByteString
keccak256 = keccakHash (1600 - 2 * 256)

-- | Keccak-224: rate = 1600 - 2 * 224 = 1152 bits.
keccak224 :: BS.ByteString -> BS.ByteString
keccak224 = keccakHash (1600 - 2 * 224)
-- | SHA3-512: rate = 1600 - 2 * 512 = 576 bits.
sha3_512 :: BS.ByteString -> BS.ByteString
sha3_512 = sha3Hash (1600 - 2 * 512)

-- | SHA3-384: rate = 1600 - 2 * 384 = 832 bits.
sha3_384 :: BS.ByteString -> BS.ByteString
sha3_384 = sha3Hash (1600 - 2 * 384)

-- | SHA3-256: rate = 1600 - 2 * 256 = 1088 bits.
sha3_256 :: BS.ByteString -> BS.ByteString
sha3_256 = sha3Hash (1600 - 2 * 256)

-- | SHA3-224: rate = 1600 - 2 * 224 = 1152 bits.
sha3_224 :: BS.ByteString -> BS.ByteString
sha3_224 = sha3Hash (1600 - 2 * 224)
-- | Variable-length sponge output: pad the message for @rate@ (passed to
-- the padding function in bytes), absorb it, and squeeze exactly
-- @outputBytes@ bytes.
shakeFunction :: (Int -> BS.ByteString -> V.Vector Word8) -> Int
              -> Int -> BS.ByteString -> BS.ByteString
shakeFunction paddingFunction rate outputBytes message =
  squeeze rate outputBytes $ absorb rate $ paddingFunction (rate `div` 8) message
-- | SHAKE128 (rate = 1600 - 2 * 128 = 1344 bits) producing @outputBits@ of
-- output. The bit count is converted to bytes with 'div', so a value that
-- is not a multiple of 8 is rounded down to whole bytes.
shake128 :: Int -> BS.ByteString -> BS.ByteString
shake128 outputBits = shakeFunction paddingShake (1600 - 2 * 128) (outputBits `div` 8)

-- | SHAKE256 (rate = 1600 - 2 * 256 = 1088 bits) producing @outputBits@ of
-- output, rounded down to whole bytes in the same way as 'shake128'.
shake256 :: Int -> BS.ByteString -> BS.ByteString
shake256 outputBits = shakeFunction paddingShake (1600 - 2 * 256) (outputBits `div` 8)
-- | Multi-rate padding to a whole number of @bitrateBytes@-byte blocks.
--
-- Appends between 1 and @bitrateBytes@ bytes to the message:
--
-- * the first appended byte is @padByte@;
-- * the last appended byte has bit 0x80 set;
-- * everything in between is zero.
--
-- When exactly one byte is appended, it carries both markers
-- (@padByte .|. 0x80@).
multiratePadding :: Int -> Word8 -> BS.ByteString -> V.Vector Word8
multiratePadding bitrateBytes padByte input = V.generate (msgLen + padLen) byteAt
  where
    msgLen = BS.length input
    padLen = bitrateBytes - BS.length input `mod` bitrateBytes
    lastIx = msgLen + padLen - 1
    byteAt i
      | i < msgLen = BS.index input i
      | otherwise = paddingAt i
    -- OR the two markers together so the single-byte case needs no
    -- separate branch.
    paddingAt i =
      (if i == msgLen then padByte else 0) .|. (if i == lastIx then 0x80 else 0)
-- | Original Keccak padding: the first padding byte is 0x01.
paddingKeccak :: Int -> BS.ByteString -> V.Vector Word8
paddingKeccak bitrateBytes input = multiratePadding bitrateBytes 0x01 input

-- | SHA-3 padding: the first padding byte is 0x06 (the SHA-3 domain bits
-- followed by the start of the padding).
paddingSha3 :: Int -> BS.ByteString -> V.Vector Word8
paddingSha3 bitrateBytes input = multiratePadding bitrateBytes 0x06 input

-- | SHAKE padding: the first padding byte is 0x1F (the SHAKE domain bits
-- followed by the start of the padding).
paddingShake :: Int -> BS.ByteString -> V.Vector Word8
paddingShake bitrateBytes input = multiratePadding bitrateBytes 0x1F input
-- | Pack bytes into 64-bit lanes, eight bytes per lane, little-endian:
-- byte @k@ of each group lands in bits @8k .. 8k+7@. A trailing group of
-- fewer than eight bytes still yields a lane, with its high bytes zero.
toBlocks :: V.Vector Word8 -> V.Vector Word64
toBlocks bytes = V.generate laneCount laneAt
  where
    laneCount = (V.length bytes + 7) `div` 8
    laneAt i =
      let start = 8 * i
          chunk = V.slice start (min 8 (V.length bytes - start)) bytes
      in V.ifoldl' (\acc k octet -> acc .|. shiftL (fromIntegral octet) (8 * k)) 0 chunk
-- | Absorb a padded message into a fresh all-zero state and return the
-- state after the final permutation.
--
-- The message is expected to be a whole number of @rate@-bit blocks (as
-- produced by the padding functions in this module); otherwise the lane
-- indexing in 'absorbBlock' fails.
absorb :: Int -> V.Vector Word8 -> V.Vector Word64
absorb rate padded = absorbBlock rate emptyState (toBlocks padded)
-- | Absorb lanes into a Keccak state one @rate@-bit block at a time.
--
-- For each block, the block's lanes are XORed into the state and then
-- 'keccakF' is applied. Input lanes are in the standard Keccak order: lane
-- @x + 5 * y@ of the block is lane (x, y). The state vector stores lane
-- (x, y) at index @5 * x + y@, and @inputIx@ below converts between the two
-- orderings.
--
-- Returns the state unchanged for empty input. Calls 'error' when @rate@ is
-- below one lane (64 bits): no lanes would be consumed per block, and the
-- recursion would never terminate.
absorbBlock :: Int -> V.Vector Word64 -> V.Vector Word64 -> V.Vector Word64
absorbBlock !rate !state !input
  | V.null input = state
  | lanesPerBlock < 1 =
      error "Crypto.Hash.Keccak.absorbBlock: rate must be at least 64 bits"
  | otherwise = absorbBlock rate (keccakF state') (V.drop lanesPerBlock input)
  where
    lanesPerBlock = div rate laneWidth
    state' = V.imap mixIn state
    mixIn z el
      | inputIx < lanesPerBlock = el `xor` (input ! inputIx)
      | otherwise = el
      where
        -- State index z = 5*x + y  ->  input lane index x + 5*y.
        inputIx = div z 5 + 5 * mod z 5
-- | Squeeze @l@ bytes of output from an absorbed Keccak state.
--
-- Output lane @i@ is lane (i mod 5, i div 5), which this module stores at
-- index @5 * (i mod 5) + i div 5@. The first @rate / 64@ lanes are emitted
-- per block, and 'keccakF' is applied before each further block. Lanes are
-- serialised little-endian, and the result is cut to exactly @l@ bytes
-- (empty when @l <= 0@).
--
-- Calls 'error' if output is requested with @rate@ below 64 bits. No lane
-- could be emitted per block, and the squeeze loop would never terminate.
squeeze :: Int -> Int -> V.Vector Word64 -> BS.ByteString
squeeze !rate !l !state =
  BS.take l . LBS.toStrict . BS.toLazyByteString
    . V.foldl' (\acc n -> acc <> BS.word64LE n) mempty
    $ outputLanes
  where
    bytesPerLane = div laneWidth 8
    -- Integer ceiling division: avoids a Double round trip (and the
    -- type-defaulting warning from an unannotated 'ceiling').
    lanesToExtract = div (l + bytesPerLane - 1) bytesPerLane
    threshold = div rate laneWidth
    outputLanes = V.unfoldrN lanesToExtract extract (0, state)
    extract (x, s)
      | x < threshold = Just (s ! (div x 5 + mod x 5 * 5), (succ x, s))
      | threshold < 1 =
          error "Crypto.Hash.Keccak.squeeze: rate must be at least 64 bits"
      | otherwise = extract (0, keccakF s)
-- | The Keccak-f[1600] permutation: 'rounds' rounds of
-- theta, rho, pi, chi and iota, with round indices @0 .. rounds - 1@.
keccakF :: V.Vector Word64 -> V.Vector Word64
keccakF !state = foldl' step state [0 .. rounds - 1]
  where
    step !s !r = iota r (chi (pi (rho (theta s))))
-- | theta step: XOR every lane with the parities of two neighbouring
-- columns.
--
-- @column ! x@ is the XOR of the five lanes with first coordinate @x@ (the
-- contiguous slice @5x .. 5x+4@). Lane @z@ is then XORed with
-- @column[x-1] xor rotl(column[x+1], 1)@, where @x = z div 5@.
theta :: V.Vector Word64 -> V.Vector Word64
theta !state = V.imap (\z lane -> lane `xor` (effect ! div z 5)) state
  where
    column = V.generate 5 (\x -> V.foldl1' xor (V.slice (5 * x) 5 state))
    effect = V.generate 5 (\x ->
      (column ! mod (x + 4) 5) `xor` rotateL (column ! mod (x + 1) 5) 1)
{-# INLINE theta #-}
-- | rho step: rotate each lane left by its entry in 'rotationConstants'.
rho :: V.Vector Word64 -> V.Vector Word64
rho !state = V.zipWith rotateL state rotationConstants
{-# INLINE rho #-}
-- | pi step: reorder lanes. Output lane @i@ is input lane
-- @piConstants ! i@.
pi :: V.Vector Word64 -> V.Vector Word64
pi !state = V.map (state !) piConstants
{-# INLINE pi #-}
-- | chi step: the non-linear step. Each lane is combined with the two
-- lanes after it along the first coordinate (indices @z + 5@ and
-- @z + 10@, wrapping modulo 25): @a xor (not b and c)@.
chi :: V.Vector Word64 -> V.Vector Word64
chi !b = V.imap mix b
  where
    neighbour z k = b ! ((z + 5 * k) `mod` 25)
    mix z lane = lane `xor` (complement (neighbour z 1) .&. neighbour z 2)
{-# INLINE chi #-}
-- | iota step: XOR the round constant for @roundNumber@ into lane 0. All
-- other lanes are left unchanged.
iota :: Int -> V.Vector Word64 -> V.Vector Word64
iota !roundNumber !state = state V.// [(0, V.head state `xor` rc)]
  where
    rc = roundConstants ! roundNumber
{-# INLINE iota #-}