{-# LANGUAGE DeriveAnyClass       #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}

module ZkFold.Symbolic.Algorithms.RSA
    ( sign
    , verify
    , signVar
    , verifyVar
    , RSA
    , PublicKey (..)
    , PrivateKey (..)
    , Signature
    ) where

import           Control.DeepSeq                      (NFData, force)
import           GHC.Generics                         (Generic)
import           Prelude                              (($))
import qualified Prelude                              as P

import           ZkFold.Base.Algebra.Basic.Number
import           ZkFold.Base.Data.Vector              (Vector)
import           ZkFold.Symbolic.Algorithms.Hash.SHA2 (SHA2, sha2, sha2Var)
import           ZkFold.Symbolic.Class
import           ZkFold.Symbolic.Data.Bool            (Bool, (&&))
import           ZkFold.Symbolic.Data.ByteString      (ByteString)
import           ZkFold.Symbolic.Data.Class
import           ZkFold.Symbolic.Data.Combinators     (Ceil, GetRegisterSize, Iso (..), KnownRegisters,
                                                       NumberOfRegisters, RegisterSize (..), Resize (..))
import           ZkFold.Symbolic.Data.Eq
import           ZkFold.Symbolic.Data.Input           (SymbolicInput, isValid)
import           ZkFold.Symbolic.Data.UInt            (OrdWord, UInt, expMod)
import           ZkFold.Symbolic.Data.VarByteString   (VarByteString)

-- | An RSA signature: a byte string of the same width as the key modulus.
type Signature keyLen ctx = ByteString keyLen ctx

-- | RSA private key. Both components are @keyLen@-wide unsigned integers.
data PrivateKey keyLen ctx
    = PrivateKey
        { prvD :: UInt keyLen 'Auto ctx  -- ^ private exponent (used as the power in 'expMod')
        , prvN :: UInt keyLen 'Auto ctx  -- ^ modulus (used as the modulus in 'expMod')
        }

-- Generic representation; the NFData and SymbolicData instances below are
-- derived through it.
deriving instance Generic (PrivateKey keyLen ctx)

deriving instance
    ( NFData (ctx (Vector (NumberOfRegisters (BaseField ctx) keyLen 'Auto)))
    ) => NFData (PrivateKey keyLen ctx)

deriving instance
    ( P.Eq (ctx (Vector (NumberOfRegisters (BaseField ctx) keyLen 'Auto)))
    ) => P.Eq (PrivateKey keyLen ctx)

deriving instance
    ( P.Show (ctx (Vector (NumberOfRegisters (BaseField ctx) keyLen 'Auto)))
    , P.Show (BaseField ctx)
    ) => P.Show (PrivateKey keyLen ctx)

deriving instance
    ( KnownRegisters ctx keyLen 'Auto
    , Symbolic ctx
    ) => SymbolicData (PrivateKey keyLen ctx)

-- A private key is a valid circuit input exactly when both the exponent and
-- the modulus are valid 'UInt' inputs.
instance
    ( Symbolic ctx
    , KnownNat keyLen
    , KnownRegisters ctx keyLen 'Auto
    ) => SymbolicInput (PrivateKey keyLen ctx) where
    isValid PrivateKey { prvD = d, prvN = n } = isValid d && isValid n

-- | Bit width reserved for the RSA public exponent.
--
-- NOTE(review): 18 bits is enough for the common exponent 65537 (2^16 + 1);
-- confirm that this is the intended bound.
type PubExponentSize = 18

-- | RSA public key: a short public exponent and a @keyLen@-wide modulus.
data PublicKey keyLen ctx
    = PublicKey
        { pubE :: UInt PubExponentSize 'Auto ctx  -- ^ public exponent (power in 'expMod')
        , pubN :: UInt keyLen 'Auto ctx           -- ^ modulus (modulus in 'expMod')
        }

-- Generic representation; the NFData and SymbolicData instances below are
-- derived through it. Every instance needs constraints for both register
-- layouts, the keyLen-wide modulus and the PubExponentSize-wide exponent.
deriving instance Generic (PublicKey keyLen ctx)

deriving instance
    ( NFData (ctx (Vector (NumberOfRegisters (BaseField ctx) PubExponentSize 'Auto)))
    , NFData (ctx (Vector (NumberOfRegisters (BaseField ctx) keyLen 'Auto)))
    ) => NFData (PublicKey keyLen ctx)

deriving instance
    ( P.Eq (ctx (Vector (NumberOfRegisters (BaseField ctx) PubExponentSize 'Auto)))
    , P.Eq (ctx (Vector (NumberOfRegisters (BaseField ctx) keyLen 'Auto)))
    ) => P.Eq (PublicKey keyLen ctx)

deriving instance
    ( P.Show (BaseField ctx)
    , P.Show (ctx (Vector (NumberOfRegisters (BaseField ctx) PubExponentSize 'Auto)))
    , P.Show (ctx (Vector (NumberOfRegisters (BaseField ctx) keyLen 'Auto)))
    ) => P.Show (PublicKey keyLen ctx)

deriving instance
    ( KnownRegisters ctx keyLen 'Auto
    , KnownRegisters ctx PubExponentSize 'Auto
    , Symbolic ctx
    ) => SymbolicData (PublicKey keyLen ctx)

-- A public key is a valid circuit input exactly when both the exponent and
-- the modulus are valid 'UInt' inputs.
instance
    ( Symbolic ctx
    , KnownNat keyLen
    , KnownRegisters ctx PubExponentSize 'Auto
    , KnownRegisters ctx keyLen 'Auto
    ) => SymbolicInput (PublicKey keyLen ctx) where
    isValid PublicKey { pubE = e, pubN = n } = isValid e && isValid n

-- | Constraints shared by 'sign', 'verify', 'signVar' and 'verifyVar'.
--
-- Bundles the SHA-256 requirement on @msgLen@, the type-level widths used by
-- 'expMod', and the 'NFData' instances needed to 'force' intermediate values.
type RSA keyLen msgLen ctx =
    ( SHA2 "SHA256" ctx msgLen
      -- key width and its double ('expMod' requires @KnownNat (2 * m)@)
    , KnownNat keyLen
    , KnownNat (2 * keyLen)
    , KnownNat (Ceil (GetRegisterSize (BaseField ctx) (2 * keyLen) 'Auto) OrdWord)
      -- register layouts for single- and double-width unsigned integers
    , KnownRegisters ctx keyLen 'Auto
    , KnownRegisters ctx (2 * keyLen) 'Auto
      -- deep-evaluation support for the underlying context values
    , NFData (ctx (Vector keyLen))
    , NFData (ctx (Vector (NumberOfRegisters (BaseField ctx) keyLen 'Auto)))
    , NFData (ctx (Vector (NumberOfRegisters (BaseField ctx) (2 * keyLen) 'Auto)))
    )

-- | Sign a fixed-length message with an RSA private key.
--
-- The message is hashed with SHA-256. The 256-bit digest is read as an
-- unsigned integer @m@, and the signature is @m ^ prvD mod prvN@, converted
-- back to a byte string. No padding scheme (e.g. PKCS#1) is applied to the
-- digest here.
sign
    :: forall keyLen msgLen ctx
    .  RSA keyLen msgLen ctx
    => ByteString msgLen ctx
    -> PrivateKey keyLen ctx
    -> Signature keyLen ctx
sign msg PrivateKey { prvD = d, prvN = n } = force $ from $ expMod msgI d n
    where
        -- SHA-256 digest of the message
        h :: ByteString 256 ctx
        h = sha2 @"SHA256" msg

        -- the digest reinterpreted as an unsigned integer (the RSA base)
        msgI :: UInt 256 'Auto ctx
        msgI = from h

-- | Check an RSA signature over a fixed-length message.
--
-- Computes @sig ^ pubE mod pubN@ and compares it with the SHA-256 digest of
-- the message, widened to @keyLen@ via 'resize'. This mirrors 'sign', which
-- applies no padding to the digest.
verify
    :: forall keyLen msgLen ctx
    .  RSA keyLen msgLen ctx
    => ByteString msgLen ctx
    -> Signature keyLen ctx
    -> PublicKey keyLen ctx
    -> Bool ctx
verify msg sig PublicKey { pubE = e, pubN = n } = target == input
    where
        -- SHA-256 digest of the message
        h :: ByteString 256 ctx
        h = sha2 @"SHA256" msg

        -- the signature raised to the public exponent, modulo the modulus
        target :: UInt keyLen 'Auto ctx
        target = force $ expMod (from sig :: UInt keyLen 'Auto ctx) e n

        -- the digest as a keyLen-wide integer
        -- NOTE(review): if keyLen < 256, 'resize' presumably truncates the
        -- digest; confirm callers always use keys wider than 256 bits.
        input :: UInt keyLen 'Auto ctx
        input = force $ resize (from h :: UInt 256 'Auto ctx)

-- | Sign a variable-length message (at most @msgLen@) with an RSA private key.
--
-- Identical to 'sign' except that the digest is computed with 'sha2Var'. The
-- signature is @digest ^ prvD mod prvN@, and no padding is applied.
signVar
    :: forall keyLen msgLen ctx
    .  RSA keyLen msgLen ctx
    => VarByteString msgLen ctx
    -> PrivateKey keyLen ctx
    -> Signature keyLen ctx
signVar msg PrivateKey { prvD = d, prvN = n } = force $ from $ expMod msgI d n
    where
        -- SHA-256 digest of the variable-length message
        h :: ByteString 256 ctx
        h = sha2Var @"SHA256" msg

        -- the digest reinterpreted as an unsigned integer (the RSA base)
        msgI :: UInt 256 'Auto ctx
        msgI = from h

-- | Check an RSA signature over a variable-length message.
--
-- Works like 'verify', but hashes with 'sha2Var'. Returns the verification
-- result together with the computed SHA-256 digest.
verifyVar
    :: forall keyLen msgLen ctx
    .  RSA keyLen msgLen ctx
    => VarByteString msgLen ctx
    -> Signature keyLen ctx
    -> PublicKey keyLen ctx
    -> (Bool ctx, ByteString 256 ctx)
verifyVar msg sig PublicKey { pubE = e, pubN = n } = (target == input, h)
    where
        -- SHA-256 digest of the variable-length message
        h :: ByteString 256 ctx
        h = sha2Var @"SHA256" msg

        -- the signature raised to the public exponent, modulo the modulus
        target :: UInt keyLen 'Auto ctx
        target = force $ expMod (from sig :: UInt keyLen 'Auto ctx) e n

        -- the digest as a keyLen-wide integer
        -- NOTE(review): if keyLen < 256, 'resize' presumably truncates the
        -- digest; confirm callers always use keys wider than 256 bits.
        input :: UInt keyLen 'Auto ctx
        input = force $ resize (from h :: UInt 256 'Auto ctx)