{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeOperators       #-}

module ZkFold.Symbolic.Compiler (
    module ZkFold.Symbolic.Compiler.ArithmeticCircuit,
    compile,
    compileIO,
    compileWith
) where

import           Data.Aeson                                 (FromJSON, ToJSON, ToJSONKey)
import           Data.Binary                                (Binary)
import           Data.Function                              (const, id, (.))
import           Data.Functor.Rep                           (Rep, Representable)
import           Data.Ord                                   (Ord)
import           Data.Tuple                                 (fst, snd)
import           GHC.Generics                               (Par1 (Par1), U1 (..))
import           Prelude                                    (FilePath, IO, Show (..), putStrLn, return, type (~), ($),
                                                             (++))

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Prelude                             (writeFileJSON)
import           ZkFold.Symbolic.Class                      (Symbolic (..), fromCircuit2F)
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit
import           ZkFold.Symbolic.Data.Bool                  (Bool (Bool))
import           ZkFold.Symbolic.Data.Class
import           ZkFold.Symbolic.Data.Input
import           ZkFold.Symbolic.MonadCircuit               (MonadCircuit (..))

{-
    ZkFold Symbolic compiler module dependency order:
    1. ZkFold.Symbolic.Compiler.ArithmeticCircuit.MerkleHash
    2. ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal
    3. ZkFold.Symbolic.Compiler.ArithmeticCircuit.Map
    4. ZkFold.Symbolic.Compiler.ArithmeticCircuit.Instance
    5. ZkFold.Symbolic.Compiler.ArithmeticCircuit
    6. ZkFold.Symbolic.Compiler
-}

-- | Requirements for compiling a function of type @f@ in the symbolic
-- context @c@, where @s@ is the support (input) type of @f@.
--
-- Both @f@ (as 'SymbolicData') and @s@ (as 'SymbolicInput') must live in
-- the same context @c@, and @c@ itself must be 'Symbolic'.
type CompilesWith c s f =
  ( Symbolic c
  , SymbolicData f
  , Support f ~ s
  , Context f ~ c
  , SymbolicInput s
  , Context s ~ c
  )

-- | Requirements for rebuilding a value of type @y@ from a circuit in the
-- symbolic context @c@.
--
-- @y@ must be a 'SymbolicOutput' in context @c@ whose 'Payload' is the
-- empty functor 'U1'.
type RestoresFrom c y =
  ( Context y ~ c
  , Payload y ~ U1
  , SymbolicOutput y
  )

-- | Compilation pipeline shared by 'compile' and 'compileWith'.
--
-- @compileInternal opts sLayout sPayload f@ works in three stages:
--
-- * It rebuilds the input value of type @s@ from @sLayout@ and @sPayload@
--   using 'restore'.
--
-- * It arithmetizes @f@ on that input. It then combines the result with the
--   input's validity bit (from 'isValid') via 'fromCircuit2F', adding the
--   constraint @one - b@ on that bit.
--
-- * It applies @opts@ to the resulting circuit and runs 'optimize' on it.
--   It then 'restore's the output @y@ with an empty ('U1') payload.
compileInternal ::
  (CompilesWith c0 s f, RestoresFrom c1 y, c1 ~ ArithmeticCircuit a p i) =>
  (c0 (Layout f) -> c1 (Layout y)) ->
  c0 (Layout s) -> Payload s (WitnessField c0) -> f -> y
compileInternal opts sLayout sPayload f =
  restore . const . (,U1) . optimize . opts $
    fromCircuit2F (arithmetize f input) b $
      \r (Par1 i) -> do
        -- NOTE(review): assuming constraints are polynomials that must
        -- vanish, this pins the validity bit to 'one', so that only valid
        -- inputs satisfy the circuit.
        constraint (\x -> one - x i)
        -- The arithmetized output of @f@ passes through unchanged.
        return r
  where
    -- Validity bit of the reconstructed input.
    Bool b = isValid input
    -- The input value, rebuilt from the supplied layout and payload.
    input = restore $ const (sLayout, sPayload)

-- | @compileWith outputTransform inputTransform f@ compiles the function @f@
-- into an optimized arithmetic circuit packed inside a suitable
-- 'SymbolicData'.
--
-- @inputTransform@ maps the circuit's payload (@p@) and input (@i@) to the
-- support's payload and layout. The layout half is turned into a circuit
-- with 'naturalCircuit'; the payload half is turned into witnesses with
-- 'inputPayload'. @outputTransform@ is applied to the compiled circuit
-- before it is optimized.
compileWith ::
  forall a y p i q j s f c0 c1.
  ( CompilesWith c0 s f, c0 ~ ArithmeticCircuit a p i
  , Representable p, Representable i
  , RestoresFrom c1 y, c1 ~ ArithmeticCircuit a q j
  , Binary a, Binary (Rep p), Binary (Rep i), Ord (Rep i)) =>
  -- | Circuit transformation to apply before optimization.
  (c0 (Layout f) -> c1 (Layout y)) ->
  -- | An algorithm to prepare support argument from the circuit input.
  (forall x. p x -> i x -> (Payload s x, Layout s x)) ->
  -- | Function to compile.
  f -> y
compileWith outputTransform inputTransform =
  compileInternal outputTransform
    -- Layout part of the support, as a circuit over the inputs.
    (naturalCircuit $ \p i -> snd (inputTransform p i))
    -- Payload part of the support, as witnesses.
    (inputPayload $ \p i -> fst (inputTransform p i))

-- | @compile f@ compiles a function @f@ into an optimized arithmetic circuit
-- packed inside a suitable 'SymbolicData'.
--
-- The circuit's payload and input are exactly the support's 'Payload' and
-- 'Layout', so no input transformation is needed:
--
-- * the layout comes straight from 'idCircuit';
-- * the payload is taken as-is (@'inputPayload' 'const'@);
-- * no output transformation is applied ('id') before optimization.
compile :: forall a y f c s.
  ( CompilesWith c s f, RestoresFrom c y, Layout y ~ Layout f
  , c ~ ArithmeticCircuit a (Payload s) (Layout s))
  => f -> y
compile = compileInternal id idCircuit (inputPayload const)

-- | Compiles a function @f@ into an arithmetic circuit using 'compile' and
-- writes the result to a file.
--
-- It prints the circuit's number of constraints ('acSizeN') and number of
-- variables ('acSizeM') to standard output. The circuit is then written to
-- @scriptFile@ with 'writeFileJSON'.
compileIO ::
  forall a c p f s l .
  ( c ~ ArithmeticCircuit a p l
  , FromJSON a
  , ToJSON a
  , ToJSONKey a
  , SymbolicData f
  , Context f ~ c
  , Support f ~ s
  , ToJSON (Layout f (Var a l))
  , SymbolicInput s
  , Context s ~ c
  , Layout s ~ l
  , Payload s ~ p
  , FromJSON (Rep l)
  , ToJSON (Rep l)
  ) => FilePath -> f -> IO ()
compileIO scriptFile f = do
    -- Bound lazily: the compilation work runs when @ac@ is first demanded
    -- by the size reports below, i.e. after the banner has been printed.
    let ac = compile f :: c (Layout f)

    putStrLn "\nCompiling the script...\n"

    putStrLn $ "Number of constraints: " ++ show (acSizeN ac)
    putStrLn $ "Number of variables: " ++ show (acSizeM ac)
    writeFileJSON scriptFile ac
    putStrLn $ "Script saved: " ++ scriptFile