{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeOperators       #-}

module ZkFold.Base.Protocol.IVC.FiatShamir where

import           Data.Constraint                       (withDict)
import           Data.Constraint.Nat                   (plusMinusInverse1)
import           Prelude                               hiding (Bool (..), Eq (..), init, length, pi, scanl, unzip)

import           ZkFold.Base.Algebra.Basic.Number      (KnownNat, type (-))
import           ZkFold.Base.Data.Vector               (Vector, init, item, scanl, unfold)
import           ZkFold.Base.Protocol.IVC.CommitOpen
import           ZkFold.Base.Protocol.IVC.Oracle       (HashAlgorithm, RandomOracle (..))
import           ZkFold.Base.Protocol.IVC.SpecialSound (SpecialSoundProtocol (..))

-- | The protocol produced by 'fiatShamir': a 'SpecialSoundProtocol' with a
-- single round. Its prover message is a vector of @k@ pairs @(m, c f)@, and
-- its verifier output has type @(Vector k (c f), o)@.
type FiatShamir k i p c m o f =
    SpecialSoundProtocol 1 i p
        (Vector k (m, c f))
        (Vector k (c f), o)
        f

-- | The challenges of the Fiat-Shamired protocol (ignoring the last round).
--
-- Starting from @r0@, the commitments are absorbed one by one with the random
-- oracle, @r_j = oracle (r_{j-1}, c_j)@, which yields the @k + 1@ values
-- @[r_0 .. r_k]@ via 'scanl'. The last two values are dropped with 'init',
-- leaving the @k - 1@ values @[r_0 .. r_{k-2}]@.
--
-- @plusMinusInverse1 \@1 \@k@ provides the proof @(k + 1) - 1 ~ k@ needed to
-- type the result of the first 'init'.
--
-- NOTE(review): the prover built in 'fiatShamir' feeds @r_{j-1}@ into round
-- @j@. For @k >= 2@ this therefore hands the verifier @r_0 .. r_{k-2}@ rather
-- than @r_1 .. r_{k-1}@ (which @tail . init@ would give). Confirm which
-- challenges 'SpecialSoundProtocol's verifier expects.
transcript :: forall algo k c f .
    ( HashAlgorithm algo f
    , RandomOracle algo f f
    , RandomOracle algo (c f) f
    ) => f -> Vector k (c f) -> Vector (k-1) f
transcript r0 cs =
    withDict (plusMinusInverse1 @1 @k) $
        -- scanl: Vector (k + 1) f; first init: Vector k f; second: Vector (k - 1) f
        init $ init $ scanl (curry (oracle @algo)) r0 cs

-- | Collapse a @k@-round commit-and-open protocol into a single-round
-- protocol using the Fiat–Shamir heuristic.
--
-- * @input@ is passed through unchanged.
--
-- * The new prover runs all @k@ rounds of the wrapped prover by itself. It
--   ignores its own challenge and round-number arguments. Round 1 receives
--   @r_0 = oracle pi0@, and each later round receives
--   @oracle (previous challenge, previous commitment)@. Round numbers start at 1.
--
-- * The new verifier unwraps its single message and recomputes the challenges
--   from the commitments with 'transcript', starting from @oracle pi@. It then
--   delegates to the wrapped verifier and ignores its own (empty) challenge
--   vector.
fiatShamir :: forall algo k i p c m o f .
    ( KnownNat k
    , HashAlgorithm algo f
    , RandomOracle algo f f
    , RandomOracle algo (i f) f
    , RandomOracle algo (c f) f
    ) => CommitOpen k i p c m o f -> FiatShamir k i p c m o f
fiatShamir SpecialSoundProtocol {..} =
    SpecialSoundProtocol input prover' verifier'
  where
    -- Unfold the k rounds. The state is (current challenge, round number);
    -- each round's commitment is hashed into the next challenge.
    prover' pi0 w _ _ =
        let r0 = oracle @algo pi0
            step (r, rnd) =
                let (m', c') = prover pi0 w r rnd
                in ((m', c'), (oracle @algo (r, c'), rnd + 1))
        in unfold step (r0, 1)

    -- Recompute the challenges from the commitments (the snd components) and
    -- run the wrapped verifier on the full k-message transcript.
    verifier' pi pms' _ =
        let pms = item pms'
            r0  = oracle @algo pi
            rs  = transcript @algo r0 (fmap snd pms)
        in verifier pi pms rs