{-# LANGUAGE TypeOperators #-}

module ZkFold.Base.Protocol.IVC.Predicate where

import           Data.Binary                           (Binary)
import           GHC.Generics                          (U1 (..), (:*:) (..))
import           Prelude                               hiding (Num (..), drop, head, replicate, take, zipWith)

import           ZkFold.Base.Data.Package              (packed, unpacked)
import           ZkFold.Base.Protocol.IVC.StepFunction (FunctorAssumptions, StepFunction, StepFunctionAssumptions)
import           ZkFold.Symbolic.Class
import           ZkFold.Symbolic.Compiler              (ArithmeticCircuit, compileWith, guessOutput, hlmap)
import           ZkFold.Symbolic.Data.FieldElement     (FieldElement (..))
import           ZkFold.Symbolic.Interpreter           (Interpreter (..))

-- | Circuit form of a 'Predicate' over @a@: an 'ArithmeticCircuit'
-- parameterised by @i :*: p@ and @i@, with no outputs ('U1').
-- NOTE(review): judging by the type of 'compileWith', the first functor
-- argument is the circuit payload and the second its input layout — confirm.
type PredicateCircuit a i p =
    ArithmeticCircuit a (i :*: p) i U1

-- | A predicate over @i@ and @p@, available in two forms: a direct evaluator
-- working on plain values of type @a@, and an arithmetic circuit.
data Predicate a i p = Predicate
    { predicateEval    :: i a -> p a -> i a
      -- ^ Direct evaluation: given an @i a@ and a payload @p a@, produces a
      -- new @i a@.
    , predicateCircuit :: PredicateCircuit a i p
      -- ^ The predicate as an arithmetic circuit over @a@.
    }

-- | Constraints required by 'predicate': @a@ must be 'Arithmetic' and have a
-- 'Binary' instance, and both functors @i@ and @p@ must satisfy
-- 'FunctorAssumptions'.
type PredicateAssumptions a i p =
    ( Arithmetic a, Binary a
    , FunctorAssumptions i, FunctorAssumptions p
    )

-- | Build a 'Predicate' from a 'StepFunction'.
--
-- The step function is first adapted into a context-polymorphic function
-- @func'@ on packed values. That single definition is then used twice:
--
-- * 'predicateEval' runs it in the 'Interpreter' context, i.e. directly on
--   values of type @a@;
-- * 'predicateCircuit' compiles it into an 'ArithmeticCircuit' with
--   'compileWith'.
predicate :: forall a i p . PredicateAssumptions a i p
    => StepFunction a i p -> Predicate a i p
predicate func =
    let
        -- The step function works on containers of 'FieldElement's; this
        -- adapter unpacks each packed argument into per-element
        -- 'FieldElement's, applies the step function, and packs the result.
        func' :: forall f ctx . StepFunctionAssumptions a f ctx => ctx i -> ctx p -> ctx i
        func' x' u' =
            let
                x :: i (FieldElement ctx)
                x = FieldElement <$> unpacked x'
                u :: p (FieldElement ctx)
                u = FieldElement <$> unpacked u'
            in
                packed . fmap fromFieldElement $ func x u

        -- Native evaluation: wrap the raw inputs in 'Interpreter', run the
        -- adapter, and unwrap the result.
        predicateEval :: i a -> p a -> i a
        predicateEval x u = runInterpreter $ func' (Interpreter x) (Interpreter u)

        -- Circuit compilation. Per its type, 'guessOutput' moves the compiled
        -- circuit's output (@i@) into its input position, giving an input of
        -- @U1 :*: i@ and no outputs; 'hlmap' (U1 :*:) then re-indexes that
        -- input to plain @i@. The lambda builds, from the circuit payload
        -- @i :*: p@, the (payload, layout) pair for func''s arguments —
        -- empty payloads and a layout of @i :*: p :*: U1@ (the trailing 'U1'
        -- presumably for the result; confirm against 'compileWith').
        predicateCircuit :: PredicateCircuit a i p
        predicateCircuit =
            hlmap (U1 :*:) $
            compileWith @a guessOutput
                (\(i :*: p) U1 -> (U1 :*: U1 :*: U1, i :*: p :*: U1))
                func'
    in Predicate {..}