{-# LANGUAGE MagicHash, UnboxedTuples, CApiFFI, UnliftedFFITypes, BangPatterns #-}

-------------------------------------------------------------------------------
-- |
-- Module:      Crypto.Sha256.Subtle
-- Copyright:   (c) 2024 Auth Global
-- License:     Apache2
--
-------------------------------------------------------------------------------

module Crypto.Sha256.Subtle where

import           Data.Array.Byte
import           Data.Bits((.&.))
import           Data.ByteString (ByteString)
import qualified Data.ByteString as B
import           Data.ByteString.Unsafe(unsafeUseAsCStringLen)
import           Data.Word
import           Foreign.C
import           Foreign.Ptr
import           GHC.Exts
import           GHC.IO

import           Crypto.HashString
import           Crypto.HashString.FFI(HashString(..))

-- | 64 zero bytes, i.e. exactly one SHA-256 block's worth of zeros.
nullBuffer :: ByteString
nullBuffer = B.replicate 64 0

-- | An immutable SHA-256 chaining state (eight 32-bit words) held in an
--   unlifted byte array.
type Sha256State# = ByteArray#

-- | Mutable counterpart of 'Sha256State#'.
type MutableSha256State# = MutableByteArray#

-- | An immutable SHA-256 context held in an unlifted byte array.  This is
--   only an alias, so the type checker cannot tell it apart from
--   'Sha256State#'.
type Sha256Ctx# = ByteArray#

-- | Mutable counterpart of 'Sha256Ctx#'.
type MutableSha256Ctx# = MutableByteArray#

-- | An immutable SHA-256 chaining state stored in a 'ByteArray'.
--
--   The array may be longer than 32 bytes (see 'sha256state_fromCtxInplace');
--   comparisons only pass a word count of 8 to the C comparison routine.
newtype Sha256State = Sha256State { unSha256State :: ByteArray }

-- | Equality is defined through 'compare', so it uses the same C comparison
--   routine as the 'Ord' instance.
instance Eq Sha256State where
  x == y = compare x y == EQ

-- | Ordering is delegated to @hs_sha256_const_memcmp_uint32be@ with a count
--   of 8 (presumably 8 big-endian 32-bit words, per the routine's name); the
--   sign of its result is mapped onto 'Ordering'.
instance Ord Sha256State where
  compare (Sha256State (ByteArray x)) (Sha256State (ByteArray y)) =
    compare (c_const_memcmp_uint32be x y 8) 0

-- | An immutable SHA-256 context stored in a 'ByteArray'.  Its first 32
--   bytes are the chaining state (see 'sha256state_fromCtx'); contexts built
--   by 'sha256state_runWith' are @40 + (n .&. 0x3F)@ bytes long.
newtype Sha256Ctx = Sha256Ctx { unSha256Ctx :: ByteArray }

-- | Equality is defined through 'compare', so it uses the same C comparison
--   routine as the 'Ord' instance.
instance Eq Sha256Ctx where
  x == y = compare x y == EQ

-- | Ordering is delegated to @hs_sha256_const_memcmp_ctx@; the sign of its
--   result is mapped onto 'Ordering'.  No length is passed, so the C side
--   must work out how much of each context to compare.
instance Ord Sha256Ctx where
  compare (Sha256Ctx (ByteArray x)) (Sha256Ctx (ByteArray y)) =
    compare (c_const_memcmp_ctx x y) 0

-- | The initial SHA-256 chaining state: 32 bytes copied out of the C object
--   @hs_sha256_init@ into a fresh byte array.
sha256state_init :: Sha256State
sha256state_init =
  unsafePerformIO . IO $ \st ->
    let !(Ptr addr) = c_sha256_init
        !(# st0, a #) = newByteArray# 32# st
        -- copyAddrToByteArray#'s two Int# arguments are the destination
        -- offset and the number of BYTES to copy (the same convention as
        -- copyByteArray#), so 32# copies all eight Word32s.
        st1 :: State# RealWorld
        st1 = copyAddrToByteArray# addr a 0# 32# st0
        !(# st2, b #) = unsafeFreezeByteArray# a st1
     in (# st2, Sha256State (ByteArray b) #)

-- | Feed input bytes into a 'Sha256State', returning a new 32-byte state.
--
--   Note that this function only processes as many complete 64-byte blocks
--   as possible, and then discards the remainder of the input.  It also does
--   nothing to track the number of bytes that have been fed into the state;
--   that has to be done externally.
sha256state_feed :: ByteString -> Sha256State -> Sha256State
sha256state_feed bytes (Sha256State (ByteArray p)) =
  unsafePerformIO . unsafeUseAsCStringLen bytes $ \(bp, bl) -> IO $ \st ->
    let !(# st0, a #) = newByteArray# 32# st
        -- The Word64 returned by hs_sha256_update is deliberately ignored.
        !(# st1, _ #) = unIO (c_sha256_update p bp (fromIntegral bl) a) st0
        !(# st2, b #) = unsafeFreezeByteArray# a st1
     in (# st2, Sha256State (ByteArray b) #)

-- | Reinterpret a 'Sha256Ctx' as a 'Sha256State' without copying, by reusing
--   the context's underlying 'ByteArray'.
--
--   The downside is that the result retains all of the context's bytes past
--   the 32-byte state: at least 8 and up to 71 unnecessary bytes, depending
--   on how many bytes are buffered in the context.  That could grow to 72
--   once this binding supports mutable contexts that can be frozen into
--   immutable contexts without copying.
sha256state_fromCtxInplace :: Sha256Ctx -> Sha256State
sha256state_fromCtxInplace (Sha256Ctx a) = Sha256State a

-- | Convert a 'Sha256Ctx' to a 'Sha256State' by copying the first 32 bytes of
--   the context into a fresh array, so the result is as small as possible.
sha256state_fromCtx :: Sha256Ctx -> Sha256State
sha256state_fromCtx (Sha256Ctx (ByteArray ctx)) =
  unsafePerformIO . IO $ \st ->
    let !(# st0, a #) = newByteArray# 32# st
        -- Source offset 0, destination offset 0, 32 bytes.
        st1 :: State# RealWorld
        st1 = copyByteArray# ctx 0# a 0# 32# st0
        !(# st2, b #) = unsafeFreezeByteArray# a st1
     in (# st2, Sha256State (ByteArray b) #)

-- | Promote a 'Sha256State' to a 'Sha256Ctx' via @hs_sha256_promote_to_ctx@.
--
--   @blocks@ is the number of blocks already absorbed into the state, and
--   @bytes@ is further input.  The output context is allocated as
--   @40 + (length bytes .&. 0x3F)@ bytes.  Presumably the C side processes
--   the complete 64-byte blocks of @bytes@ and buffers the remainder in the
--   context; confirm against hs_sha256.h.
sha256state_runWith :: Word64 -> ByteString -> Sha256State -> Sha256Ctx
sha256state_runWith blocks bytes (Sha256State (ByteArray p)) =
    unsafePerformIO . unsafeUseAsCStringLen bytes $ \(bp, bl) -> IO $ \st ->
      let !(# st0, a #) = newByteArray# ctxLen# st
          !(# st1, () #) =
            unIO (c_sha256_promote_to_ctx p blocks bp (fromIntegral bl) a) st0
          !(# st2, b #) = unsafeFreezeByteArray# a st1
       in (# st2, Sha256Ctx (ByteArray b) #)
  where
    -- (.&.) already binds tighter than (+); the parentheses just make the
    -- intended "40 + (length mod 64)" explicit.
    !(I# ctxLen#) = 40 + (B.length bytes .&. 0x3F)

-- | Serialize a 'Sha256State' into a fresh 32-byte 'HashString' using
--   @hs_sha256_encode_state@.
sha256state_encode :: Sha256State -> HashString
sha256state_encode (Sha256State (ByteArray x)) =
    unsafePerformIO . IO $ \st ->
      let !(# st0, a #) = newByteArray# 32# st
          !(# st1, () #) = unIO (c_sha256_encode_state x a) st0
          !(# st2, b #) = unsafeFreezeByteArray# a st1
       in (# st2, HashString (ByteArray b) #)

-- | Deserialize a 'HashString' into a fresh 32-byte 'Sha256State' using
--   @hs_sha256_decode_state@.
--
--   The C routine is given no length argument, so it presumably reads a fixed
--   32 bytes from its input.  Inputs shorter than that are rejected with
--   'error' instead of letting C read past the end of the array.  Longer
--   inputs are accepted exactly as before.
sha256state_decode :: HashString -> Sha256State
sha256state_decode (HashString (ByteArray x))
  | isTrue# (sizeofByteArray# x <# 32#) =
      error "Crypto.Sha256.Subtle.sha256state_decode: input is shorter than 32 bytes"
  | otherwise =
      unsafePerformIO . IO $ \st ->
        let !(# st0, a #) = newByteArray# 32# st
            !(# st1, () #) = unIO (c_sha256_decode_state x a) st0
            !(# st2, b #) = unsafeFreezeByteArray# a st1
         in (# st2, Sha256State (ByteArray b) #)

-- These calls must be marked "unsafe", because the data structures we pass
-- in are unpinned byte arrays.  Keep that in mind when choosing the size of
-- each update.  (A separate FFI layer using safe calls might be desirable in
-- some cases.  Then again, it should be possible to work around the
-- limitations of long-running unsafe calls by using smaller updates and
-- making more calls into C.)

-- See the documentation for details:
-- https://ghc.gitlab.haskell.org/ghc/doc/users_guide/exts/ffi.html#guaranteed-call-safety

-- TODO: some C functions have more than one binding, mostly so that they can
-- be given different Haskell types.  This module should support several more
-- variants of some of these bindings.

-- | Address of the C object @hs_sha256_init@.  It is imported with @&@, so
--   no C call is made.  'sha256state_init' copies 32 bytes from here;
--   presumably this is the standard SHA-256 initial hash value.
foreign import ccall unsafe "hs_sha256.h &hs_sha256_init"
    c_sha256_init :: Ptr Word32

-- | Initialize a mutable context in place.
--   NOTE(review): the required size of the context buffer is not visible in
--   this module; confirm it against hs_sha256.h.
foreign import capi unsafe "hs_sha256.h hs_sha256_init_ctx"
    c_sha256_init_ctx :: MutableSha256Ctx# RealWorld -> IO ()

-- | Build a context from a chaining state, a block count, and extra input.
--   'sha256state_runWith' allocates the output as 40 + (length .&. 0x3F)
--   bytes.
foreign import capi unsafe "hs_sha256.h hs_sha256_promote_to_ctx"
  c_sha256_promote_to_ctx
    :: Sha256State# -- ^ @state@, a pointer to a constant array of eight Word32
    -> Word64 -- ^ @blockCount@, the number of blocks that a sha256 context has processed
    -> CString -- ^ pointer to the constant data to process
    -> CSize -- ^ length of the data to process
    -> MutableSha256Ctx# RealWorld -- ^ output pointer
    -> IO ()

-- | Absorb input into a chaining state, writing the new 32-byte state to the
--   output.  'sha256state_feed' ignores the returned Word64.
--   NOTE(review): a bare state carries no count, so the returned value is
--   presumably derived from this call's input alone; confirm.
foreign import capi unsafe "hs_sha256.h hs_sha256_update"
  c_sha256_update
    :: Sha256State# -- ^ @state@, a pointer to a constant array of eight Word32
    -> CString -- ^ pointer to the constant data to process
    -> CSize -- ^ length of the data to process
    -> MutableSha256State# RealWorld -- ^ output pointer
    -> IO Word64 -- ^ the new @count@

-- | Absorb input into an immutable context, writing the result to a separate
--   mutable context.
foreign import capi unsafe "hs_sha256.h hs_sha256_update_ctx"
  c_sha256_update_ctx
    :: Sha256Ctx# -- ^ @ctx@, a pointer to a constant sha256 context
    -> CString -- ^ pointer to the constant data to process
    -> CSize -- ^ length of the data to process
    -> MutableSha256Ctx# RealWorld -- ^ output pointer
    -> IO ()

-- | The same C function as 'c_sha256_update_ctx', but typed to take a
--   mutable context as input so that a context can be updated in place.
foreign import capi unsafe "hs_sha256.h hs_sha256_update_ctx"
  c_sha256_mutate_ctx
    :: MutableSha256Ctx# RealWorld -- ^ @ctx@, a pointer to a (mutable) sha256 context, read as input
    -> CString -- ^ pointer to the constant data to process
    -> CSize -- ^ length of the data to process
    -> MutableSha256Ctx# RealWorld -- ^ output pointer, can be same as the input context
    -> IO ()

-- | Read a count out of the given array (imported as a pure function).
--   NOTE(review): this is typed as taking a 'Sha256State#', but a bare state
--   does not track a count (see 'sha256state_feed').  It is presumably meant
--   to receive a context; since both aliases are 'ByteArray#', the type
--   checker cannot catch a mix-up.
foreign import capi unsafe "hs_sha256.h hs_sha256_get_count"
  c_sha256_get_count
    :: Sha256State#
    -> Word64

-- | @hs_sha256_finalize_ctx_bits@, writing its output through a raw pointer.
--   The three finalize bindings import the same C symbol and differ only in
--   the Haskell types of the context and output arguments.
--   NOTE(review): the Word64 is presumably a bit length for the trailing
--   data, per the symbol name; confirm against hs_sha256.h.
foreign import capi unsafe "hs_sha256.h hs_sha256_finalize_ctx_bits"
  c_sha256_finalize_ctx_bits
    :: Sha256Ctx#
    -> CString
    -> Word64
    -> Ptr Word8
    -> IO ()

-- | @hs_sha256_finalize_ctx_bits@, writing its output into a mutable byte
--   array.
foreign import capi unsafe "hs_sha256.h hs_sha256_finalize_ctx_bits"
  c_sha256_finalize_ctx_bits_ba
    :: Sha256Ctx#
    -> CString
    -> Word64
    -> MutableByteArray# RealWorld
    -> IO ()

-- | @hs_sha256_finalize_ctx_bits@, taking a mutable context as input.
foreign import capi unsafe "hs_sha256.h hs_sha256_finalize_ctx_bits"
  c_sha256_finalize_mutable_ctx_bits
    :: MutableSha256Ctx# RealWorld
    -> CString
    -> Word64
    -> CString
    -> IO ()

-- | Serialize a chaining state into the output array.  'sha256state_encode'
--   allocates a 32-byte output.
foreign import capi unsafe "hs_sha256.h hs_sha256_encode_state"
  c_sha256_encode_state
    :: Sha256State#
    -> MutableByteArray# RealWorld
    -> IO ()

-- | Deserialize a chaining state from the input array.  No input length is
--   passed, so callers must ensure the input is large enough;
--   'sha256state_decode' allocates a 32-byte output.
foreign import capi unsafe "hs_sha256.h hs_sha256_decode_state"
  c_sha256_decode_state
    :: ByteArray#
    -> MutableSha256State# RealWorld
    -> IO ()

-- | Compare two arrays, imported as a pure function.  The 'Ord' instance for
--   'Sha256State' passes 8 as the Word32 argument (presumably a count of
--   32-bit words) and maps the sign of the CInt result onto 'Ordering'.
--   Per its name, presumably constant-time; confirm against hs_sha256.h.
foreign import capi unsafe "hs_sha256.h hs_sha256_const_memcmp_uint32be"
  c_const_memcmp_uint32be
    :: ByteArray#
    -> ByteArray#
    -> Word32
    -> CInt

-- | Compare two contexts, imported as a pure function.  The 'Ord' instance
--   for 'Sha256Ctx' maps the sign of the CInt result onto 'Ordering'.  No
--   length is passed, so the C side must determine how much of each context
--   to compare.
foreign import capi unsafe "hs_sha256.h hs_sha256_const_memcmp_ctx"
  c_const_memcmp_ctx
    :: ByteArray#
    -> ByteArray#
    -> CInt