{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeOperators       #-}

module ZkFold.Base.Protocol.NonInteractiveProof.Internal where

import           Crypto.Hash.BLAKE2.BLAKE2b                 (hash)
import           Data.ByteString                            (ByteString)
import           Data.Maybe                                 (fromJust)
import qualified Data.Vector                                as V
import           Data.Word                                  (Word8)
import           Numeric.Natural                            (Natural)
import           Prelude                                    hiding (Num ((*)), sum)

import           ZkFold.Base.Algebra.Basic.Class            (Field, MultiplicativeSemigroup ((*)), sum)
import           ZkFold.Base.Algebra.EllipticCurve.Class    (EllipticCurve (..), Point)
import           ZkFold.Base.Algebra.Polynomials.Univariate (Poly, PolyVec, fromPolyVec)
import           ZkFold.Base.Data.ByteString

-- | Encoding of a value of type @a@ as a transcript fragment of type @ts@.
--
-- Transcripts are required to be monoids so that fragments can be appended
-- to an existing transcript with '(<>)'.
class Monoid ts => ToTranscript ts a where
    toTranscript :: a -> ts

-- | A 'ByteString' transcript accepts any 'Binary' value, encoded with
-- 'toByteString'.
instance Binary a => ToTranscript ByteString a where
    toTranscript = toByteString

-- | Append the encoding of a value to a transcript.
--
-- @transcript ts a@ is @ts <> toTranscript a@: the new fragment is placed on
-- the right-hand side of '(<>)'.
transcript :: ToTranscript ts a => ts -> a -> ts
transcript ts a = ts <> toTranscript a

-- | Derivation of a value of type @a@ from a transcript of type @ts@.
class Monoid ts => FromTranscript ts a where
    fromTranscript :: ts -> a

-- | Derive a value from a 'ByteString' transcript. The transcript is hashed
-- with BLAKE2b using a 28-byte digest and an empty key, and the digest is
-- decoded with 'fromByteString'.
--
-- NOTE(review): this is partial. It uses 'fromJust', so it throws if the
-- 28-byte digest does not decode as an @a@. Presumably every @a@ used here
-- decodes from 28 bytes; confirm this when adding new instantiations.
instance Binary a => FromTranscript ByteString a where
    fromTranscript = fromJust . fromByteString . hash 28 mempty

-- | Draw a challenge from the current state of a transcript.
--
-- This is exactly 'fromTranscript'. It does not extend the transcript, so
-- drawing twice from the same transcript yields the same value; callers must
-- append something before drawing again.
challenge :: forall ts a . FromTranscript ts a => ts -> a
challenge = fromTranscript

-- | Draw @n@ challenges from a transcript.
--
-- Each step draws a challenge from the current transcript with 'challenge',
-- then appends a single @0 :: Word8@ with 'transcript'. The next challenge
-- is therefore drawn from the extended transcript.
--
-- Returns the challenges together with the final transcript. Challenges are
-- accumulated by prepending, so the most recently drawn one is at the head
-- of the list. For @n = 0@ the result is @([], ts0)@.
challenges :: (ToTranscript ts Word8, FromTranscript ts a) => ts -> Natural -> ([a], ts)
challenges ts0 n = go ts0 n []
  where
    -- k: challenges still to draw; acc: challenges drawn so far (newest first)
    go ts 0 acc = (acc, ts)
    go ts k acc =
        let c   = challenge ts
            ts' = ts `transcript` (0 :: Word8)
        in go ts' (k - 1) (c : acc)

-- | A non-interactive proof system described by the type @a@.
--
-- The methods build prover- and verifier-side setups from @a@, produce a
-- proof from a witness, and check a proof.
--
-- NOTE: the @core@ parameter does not occur in any method signature. Every
-- use of a method must therefore fix it explicitly, e.g. with a type
-- application; the module enables @AllowAmbiguousTypes@ for this.
class NonInteractiveProof a core where
    -- Types associated with each part of the protocol.
    type Transcript a
    type SetupProve a
    type SetupVerify a
    type Witness a
    type Input a
    type Proof a

    -- | Prover-side setup built from @a@.
    setupProve :: a -> SetupProve a

    -- | Verifier-side setup built from @a@.
    setupVerify :: a -> SetupVerify a

    -- | Produce an 'Input' together with a 'Proof' for the given 'Witness'.
    prove :: SetupProve a -> Witness a -> (Input a, Proof a)

    -- | Check a 'Proof' against an 'Input' (presumably 'True' on acceptance).
    verify :: SetupVerify a -> Input a -> Proof a -> Bool

-- | Elliptic-curve and polynomial primitives whose implementation is chosen
-- by the @core@ tag (for example 'HaskellCore').
--
-- NOTE: @core@ does not occur in the method types, so it must be fixed
-- explicitly at call sites, e.g. with a type application.
class (EllipticCurve curve) => CoreFunction curve core where
    -- | Combine a vector of points with a 'PolyVec' of scalar-field elements
    -- into a single point.
    msm :: (f ~ ScalarField curve)
        => V.Vector (Point curve)
        -> PolyVec f size
        -> Point curve

    -- | Multiply two polynomials over the curve's scalar field.
    polyMul :: (f ~ ScalarField curve, Field f, Eq f)
            => Poly f
            -> Poly f
            -> Poly f

-- | Tag selecting the pure-Haskell implementation of 'CoreFunction'.
data HaskellCore

instance (EllipticCurve curve, f ~ ScalarField curve) => CoreFunction curve HaskellCore where
    -- Sum over i of @mul s_i g_i@, pairing the i-th scalar of the 'PolyVec'
    -- with the i-th point. 'V.zipWith' stops at the shorter vector, so any
    -- surplus points or scalars are silently ignored.
    msm gs pv = sum $ V.zipWith mul (fromPolyVec pv) gs

    -- Delegates to the multiplicative semigroup instance of 'Poly'.
    polyMul = (*)