{-# LANGUAGE BangPatterns #-}

-------------------------------------------------------------------------------
-- |
-- Module:      Crypto.Sha256.Pbkdf2
-- Copyright:   (c) 2024 Auth Global
-- License:     Apache2
--
-- An implementation of PBKDF2-HMAC-SHA256
--
-------------------------------------------------------------------------------

module Crypto.Sha256.Pbkdf2
     ( pbkdf2
     , pbkdf2_index
     , Pbkdf2Ctx()
     , pbkdf2Ctx_init
     , pbkdf2Ctx_feed, pbkdf2Ctx_feeds
     , pbkdf2Ctx_update, pbkdf2Ctx_updates
     , pbkdf2Ctx_finalize
     , Pbkdf2Gen()
     , pbkdf2Gen_iterate
     , pbkdf2Gen_finalize
     )
     where

import           Data.ByteString(ByteString)
import qualified Data.ByteString.Short as SB
import           Data.Function((&))
import           Data.Word
import           Crypto.HashString ( HashString )
import qualified Crypto.HashString as HS
import           Crypto.Sha256
import           Crypto.Sha256.Hmac
import           Crypto.Sha256.Hmac.Subtle
import           Crypto.Sha256.Pbkdf2.Subtle
import qualified Network.ByteOrder as NB

-- | Keep leading blocks of the list until their combined byte length reaches
--   @n@. The last block kept is cut short if it would overshoot.
--
--   Returns @[]@ for an empty list or when @n <= 0@. If the list runs out
--   first, every block is returned and the total is less than @n@.
takeHS :: Int -> [HashString] -> [HashString]
takeHS _ [] = []
takeHS n (blk : rest)
  | n <= 0 = []
  | blkLen < n = blk : takeHS (n - blkLen) rest
  | otherwise = [HS.fromShort (SB.take n short)]
  where
    short = HS.toShort blk
    blkLen = SB.length short

-- | Simple interface to PBKDF2-HMAC-SHA256. Derives the requested number of
--   bytes of output from a password and salt.
--
--   Output blocks are generated for indices 1, 2, 3, ... and concatenated,
--   with the final block truncated to the requested length. A round count of
--   0 is treated the same as 1. A non-positive length yields 'mempty'.
--
--   Reusing computations via partial application is not (yet!) supported.
--   TODO: write the pbkdf2 and pbkdf2_index functions in a point-free style.

pbkdf2
  :: ByteString -- ^ nominally the "password"
  -> ByteString -- ^ nominally the "salt"
  -> Word64 -- ^ number of rounds
  -> Int -- ^ desired length of output
  -> HashString
pbkdf2 password0 salt rounds len =
    mconcat . takeHS len $ fmap block [1 .. maxBound]
  where
    -- The hashed key and the salt-fed context do not depend on the block
    -- index, so they are built once and shared by every output block.
    saltCtx :: Pbkdf2Ctx
    saltCtx = pbkdf2Ctx_feed salt (pbkdf2Ctx_init (hmacKeyHashed password0))

    -- pbkdf2Ctx_finalize already applies the first round.
    extraRounds :: Word64
    extraRounds = max rounds 1 - 1

    block :: Word32 -> HashString
    block i =
      pbkdf2Gen_finalize (pbkdf2Gen_iterate extraRounds (pbkdf2Ctx_finalize i saltCtx))

-- | Compute a single 32-byte block of PBKDF2-HMAC-SHA256 output, selected by
--   its index. A round count of 0 is treated the same as 1.

pbkdf2_index
  :: ByteString -- ^ nominally the "password"
  -> ByteString -- ^ nominally the "salt"
  -> Word32 -- ^ the "index", returns the i-th block of output. The first index is 1, thus the result consists of bytes starting at 32*(i-1) and ending before 32*i.  The index is appended to the salt as 4 more bytes.
  -> Word64 -- ^ number of rounds
  -> HashString -- ^ 32-byte output
pbkdf2_index password0 salt index rounds =
    saltCtx
      & pbkdf2Ctx_finalize index
      & pbkdf2Gen_iterate (max rounds 1 - 1)  -- finalize already did round 1
      & pbkdf2Gen_finalize
  where
    saltCtx :: Pbkdf2Ctx
    saltCtx = pbkdf2Ctx_feed salt (pbkdf2Ctx_init (hmacKeyHashed password0))

-- | Begin a PBKDF2 computation keyed by an already-hashed HMAC key. No salt
--   bytes have been absorbed yet; append them with 'pbkdf2Ctx_update',
--   'pbkdf2Ctx_feed', or their plural variants.

pbkdf2Ctx_init :: HmacKeyHashed -> Pbkdf2Ctx
pbkdf2Ctx_init key =
  Pbkdf2Ctx
    { pbkdf2Ctx_ipadCtx = hmacKeyHashed_ipadCtx key
    , pbkdf2Ctx_password = key
    }

-- | Append some bytes to the end of the salt. Flipped version of
--   'pbkdf2Ctx_feed'.

pbkdf2Ctx_update :: Pbkdf2Ctx -> ByteString -> Pbkdf2Ctx
pbkdf2Ctx_update ctx bytes = ctx { pbkdf2Ctx_ipadCtx = absorbed }
  where
    -- Only the inner (ipad) hash state absorbs salt bytes; the key is kept.
    absorbed = sha256_update (pbkdf2Ctx_ipadCtx ctx) bytes

-- | Append zero or more bytestrings to the end of the salt. Flipped version
--   of 'pbkdf2Ctx_feeds'.

pbkdf2Ctx_updates :: Foldable f => Pbkdf2Ctx -> f ByteString -> Pbkdf2Ctx
pbkdf2Ctx_updates ctx chunks =
  let absorbed = pbkdf2Ctx_ipadCtx ctx `sha256_updates` chunks
   in ctx { pbkdf2Ctx_ipadCtx = absorbed }

-- | Append some bytes to the end of the salt. Flipped version of
--   'pbkdf2Ctx_update', convenient in @&@ pipelines.

pbkdf2Ctx_feed :: ByteString -> Pbkdf2Ctx -> Pbkdf2Ctx
pbkdf2Ctx_feed bytes ctx = pbkdf2Ctx_update ctx bytes

-- | Append zero or more bytestrings to the end of the salt. Flipped version
--   of 'pbkdf2Ctx_updates', convenient in @&@ pipelines.

pbkdf2Ctx_feeds :: Foldable f => f ByteString -> Pbkdf2Ctx -> Pbkdf2Ctx
pbkdf2Ctx_feeds chunks ctx = pbkdf2Ctx_updates ctx chunks

-- | Append the index to the end of the salt, and then initialize a
--   'Pbkdf2Gen' with one round applied.
--
--   The first-round output is used both as the generator's state and as its
--   initial XOR accumulator (the 'pbkdf2Gen_finalize' field).

pbkdf2Ctx_finalize
  :: Word32 -- ^ index of output block
  -> Pbkdf2Ctx
  -> Pbkdf2Gen
pbkdf2Ctx_finalize index ctx =
    let key = pbkdf2Ctx_password ctx
        -- Inner hash: finish the salt-fed ipad context with the 4-byte
        -- encoding of the index (from Network.ByteOrder).
        innerDigest =
          sha256_finalizeBytes_toByteString (NB.bytestring32 index) (pbkdf2Ctx_ipadCtx ctx)
        -- Outer hash over the inner digest gives the first PRF output (U_1).
        firstRound = sha256_finalizeBytes innerDigest (hmacKeyHashed_opadCtx key)
     in Pbkdf2Gen
          { pbkdf2Gen_password = key
          , pbkdf2Gen_state = firstRound
          , pbkdf2Gen_finalize = firstRound
          }

-- | Apply zero or more rounds to a pbkdf2 computation.
--
--   Each round feeds the current state through HMAC keyed with the stored
--   password to get the next state. That new state is then combined into the
--   running accumulator (the 'pbkdf2Gen_finalize' field) using 'HS.xorLeft'.

pbkdf2Gen_iterate
  :: Word64  -- ^ number of key-stretching rounds to perform
  -> Pbkdf2Gen
  -> Pbkdf2Gen
pbkdf2Gen_iterate n0 ctx = go n0 xorSum0 state0
  where
    password :: HmacKeyHashed
    password = pbkdf2Gen_password ctx

    xorSum0 :: HashString
    xorSum0 = pbkdf2Gen_finalize ctx

    state0 :: HashString
    state0 = pbkdf2Gen_state ctx

    -- Pin the counter to Word64. Without this signature GHC generalises the
    -- loop to @forall t. (Ord t, Num t) => t -> ...@, which may pass class
    -- dictionaries on every iteration of this hot loop unless it is
    -- specialised.
    go :: Word64 -> HashString -> HashString -> Pbkdf2Gen
    go n xorSum state
      | n == 0 =
          Pbkdf2Gen
            { pbkdf2Gen_password = password
            , pbkdf2Gen_finalize = xorSum
            , pbkdf2Gen_state = state
            }
      | otherwise =
          -- Bang patterns force each round's results so no chain of thunks
          -- builds up across iterations.
          let !state' = hmacKeyHashed_run password
                          & hmacCtx_finalizeBytes (HS.toByteString state)
              !xorSum' = HS.xorLeft state' xorSum
           in go (n - 1) xorSum' state'