{-# LANGUAGE AllowAmbiguousTypes  #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}

module ZkFold.Base.Protocol.Protostar.SpecialSound where

import           Data.Functor.Rep                                      (Representable (..))
import           Data.Map.Strict                                       (elems)
import           GHC.Generics                                          ((:*:) (..))
import           Prelude                                               (($))

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Basic.Number
import           ZkFold.Base.Data.Vector                               (Vector)
import qualified ZkFold.Base.Protocol.Protostar.AlgebraicMap           as AM
import           ZkFold.Base.Protocol.Protostar.ArithmetizableFunction (ArithmetizableFunction (..))
import           ZkFold.Symbolic.Class
import           ZkFold.Symbolic.Compiler

{- | Section 3.1 of the Protostar paper.

The protocol \(\Pi_{\mathrm{sps}}\) has 3 essential parameters \(k, d, l \in \mathbb{N}\),
meaning that \(\Pi_{\mathrm{sps}}\) is a \((2k - 1)\)-move protocol with verifier degree
\(d\) and output length \(l\) (i.e. the verifier checks \(l\) algebraic equations of
degree \(d\)). In each round \(i\) (\(1 \le i \le k\)), the prover
\(P_{\mathrm{sps}}(\mathsf{pi}, w, [m_j, r_j]_{j=1}^{i-1})\) generates the next message
\(m_i\) on input the public input \(\mathsf{pi}\), the witness \(w\), and the current
transcript \([m_j, r_j]_{j=1}^{i-1}\), and sends \(m_i\) to the verifier; the verifier
replies with a random challenge \(r_i \in \mathbb{F}\). After the final message \(m_k\),
the verifier computes the algebraic map \(V_{\mathrm{sps}}\) and checks that the output
is a zero vector of length \(l\).

-}
class SpecialSoundProtocol f i p m c d k a where
  -- | The type of value returned by 'verifier'; chosen per instance.
  type VerifierOutput f i p m c d k a

  -- | Compute the public input from the previous public input and the witness.
  input
    :: a                                -- ^ protocol object (e.g. an 'ArithmetizableFunction')
    -> i f                              -- ^ previous public input
    -> p f                              -- ^ witness
    -> i f                              -- ^ public input

  -- | Produce the prover's message for one round.
  prover
    :: a                                -- ^ protocol object (e.g. an 'ArithmetizableFunction')
    -> i f                              -- ^ previous public input
    -> p f                              -- ^ witness
    -> f                                -- ^ current random challenge
    -> Natural                          -- ^ round number (starting from 1)
    -> m                                -- ^ prover message

  -- | Check a complete transcript: @k@ prover messages interleaved with
  -- @k - 1@ random challenges.
  verifier
    :: a                                -- ^ protocol object (e.g. an 'ArithmetizableFunction')
    -> i f                              -- ^ public input
    -> Vector k m                       -- ^ prover messages
    -> Vector (k - 1) f                 -- ^ random challenges
    -> VerifierOutput f i p m c d k a   -- ^ verifier output

-- | Single-round (@k = 1@) special-sound protocol for an
-- 'ArithmetizableFunction': the verifier receives exactly one prover message
-- and an empty vector of random challenges.
instance (Arithmetic a, Representable i, Representable p, KnownNat (d + 1))
    => SpecialSoundProtocol a i p [a] c d 1 (ArithmetizableFunction a i p) where
  -- The verifier output is the list produced by 'AM.algebraicMap'.
  type VerifierOutput a i p [a] c d 1 (ArithmetizableFunction a i p) = [a]

  -- | The next public input is the function evaluated on the previous public
  -- input and the witness.
  input :: ArithmetizableFunction a i p -> i a -> p a -> i a
  input ArithmetizableFunction {..} = afEval

  -- | Just return the witness values on the previous public input.
  --
  -- The random challenge and the round number are ignored: this instance has
  -- a single round. The values come from 'elems', i.e. they are listed in
  -- ascending key order of the map built by 'witnessGenerator'.
  prover :: ArithmetizableFunction a i p -> i a -> p a -> a -> Natural -> [a]
  prover ArithmetizableFunction {..} pi0 w _ _ =
    -- The circuit is fed the pair (previous input :*: witness) together with
    -- the next public input computed directly by 'afEval'.
    elems $ witnessGenerator afCircuit (pi0 :*: w) (afEval pi0 w)

  -- | Evaluate the algebraic map on public inputs and prover messages.
  --
  -- The degree @d@ from the instance head is passed by type application
  -- because it does not occur in the argument types of 'AM.algebraicMap'.
  -- The trailing 'one' fills its last scalar argument; presumably the
  -- homogenisation factor, making this the plain (non-folded) check.
  -- NOTE(review): confirm against the AlgebraicMap documentation.
  verifier
    :: ArithmetizableFunction a i p
    -> i a
    -> Vector 1 [a]
    -> Vector (1 - 1) a
    -> VerifierOutput a i p [a] c d 1 (ArithmetizableFunction a i p)
  verifier af pi pm ts = AM.algebraicMap @_ @_ @d af pi pm ts one