{-# LANGUAGE TypeOperators #-}

module ZkFold.Symbolic.Compiler.ArithmeticCircuit (
        ArithmeticCircuit,
        Constraint,
        Var,
        witnessGenerator,
        -- high-level functions
        optimize,
        desugarRanges,
        emptyCircuit,
        idCircuit,
        naturalCircuit,
        inputPayload,
        guessOutput,
        -- low-level functions
        eval,
        eval1,
        exec,
        exec1,
        -- information about the system
        acSizeN,
        acSizeM,
        acSizeR,
        acSystem,
        acValue,
        acPrint,
        -- Variable mapping functions
        hlmap,
        hpmap,
        mapVarArithmeticCircuit,
        -- Arithmetization type fields
        acWitness,
        acInput,
        acOutput,
        -- Testing functions
        checkCircuit,
        checkClosedCircuit,
        isConstantInput
    ) where

import           Control.DeepSeq                                         (NFData)
import           Control.Monad                                           (foldM)
import           Control.Monad.State                                     (execState)
import           Data.Binary                                             (Binary)
import           Data.Bool                                               (bool)
import           Data.Foldable                                           (for_)
import           Data.Functor.Rep                                        (Representable (..), mzipRep)
import           Data.Map                                                hiding (drop, foldl, foldr, map, null, splitAt,
                                                                          take)
import qualified Data.Map.Monoidal                                       as M
import qualified Data.Set                                                as S
import           Data.Void                                               (absurd)
import           GHC.Generics                                            (U1 (..), (:*:))
import           Numeric.Natural                                         (Natural)
import           Prelude                                                 hiding (Num (..), drop, length, product,
                                                                          splitAt, sum, take, (!!), (^))
import           Test.QuickCheck                                         (Arbitrary, Property, arbitrary, conjoin,
                                                                          property, withMaxSuccess, (===))
import           Text.Pretty.Simple                                      (pPrint)

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Polynomials.Multivariate            (evalMonomial, evalPolynomial)
import           ZkFold.Base.Data.HFunctor                               (hmap)
import           ZkFold.Base.Data.Product                                (fstP, sndP)
import           ZkFold.Prelude                                          (length)
import           ZkFold.Symbolic.Class                                   (fromCircuit2F)
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Instance     ()
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Lookup
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Map
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Optimization
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Var          (toVar)
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Witness      (WitnessF)
import           ZkFold.Symbolic.Data.Combinators                        (expansion)
import           ZkFold.Symbolic.MonadCircuit                            (MonadCircuit (..))

--------------------------------- High-level functions --------------------------------

-- | Replaces the range check @a <= i <= b@ on variable @i@ by polynomial
-- constraints.
--
-- @i@ is decomposed once with 'expansion' into as many bit variables as @b@
-- has bits (presumably this also constrains @i@ to fit into that many bits).
-- Each bound is then evaluated by scanning the (bound bit, variable bit)
-- pairs in list order. Every step allocates a new variable holding the 0/1
-- comparison result for the bits scanned so far. The final result is
-- constrained to equal @1@.
--
-- NOTE(review): the recurrences assume 'binaryExpansion' and 'expansion'
-- list bits least-significant first — confirm.
-- NOTE(review): assumes @a <= b@. If @a@ has more bits than @b@, its extra
-- high bits are ignored, because the zip below truncates to the width of @b@.
desugarRange :: (Arithmetic a, MonadCircuit i a w m) => i -> (a, a) -> m ()
desugarRange i (a, b)
  -- When the upper bound is the largest field element, no constraint is
  -- emitted at all (this also skips any lower bound @a@).
  | b == negate one = return ()
  | otherwise = do
      let bBits = binaryExpansion (toConstant b)
          aBits = binaryExpansion (toConstant a)
      -- A single bit decomposition of @i@ is shared by both bounds.
      -- Re-expanding @i@ into the (possibly fewer) bits of @a@ would wrongly
      -- tie @i@ to a narrower width; for @a = 0@ that is zero bits.
      iBits <- expansion (length bBits) i
      -- Upper bound @i <= b@. A leading run of 1-bits of @b@ can never be
      -- exceeded, so the scan starts at the first 0-bit of @b@.
      case dropWhile ((== one) . fst) (zip bBits iBits) of
        [] -> return ()
        ((_, k0) : ds) -> do
          z <- newAssigned (one - ($ k0))
          le <- foldM (\j (c, k) -> newAssigned $ stepLE j c k) z ds
          constraint (($ le) - one)
      -- Lower bound @a <= i@. The bits of @a@ are padded with zeros to the
      -- width of @i@, so that set high bits of @i@ are taken into account.
      -- A leading run of 0-bits of @a@ is always met, so the scan starts at
      -- the first 1-bit of @a@, where the result is that bit of @i@ itself.
      case dropWhile ((== zero) . fst) (zip (aBits ++ repeat zero) iBits) of
        [] -> return ()
        ((_, k0) : ds) -> do
          ge <- foldM (\j (c, k) -> newAssigned $ stepGE j c k) k0 ds
          constraint (($ ge) - one)
  where
    -- "prefix of i <= prefix of b", given the previous result j, bit c of b
    -- and bit k of i.
    --   c = 0: i's bit must be 0 and the previous prefix must hold: j*(1-k).
    --   c = 1: holds if i's bit is 0, otherwise defer to j:  1 + k*(j-1).
    stepLE j c k
      | c == zero = ($ j) * (one - ($ k))
      | otherwise = one + ($ k) * (($ j) - one)
    -- "prefix of i >= prefix of a", given the previous result j, bit c of a
    -- and bit k of i.
    --   c = 1: i's bit must be 1 and the previous prefix must hold: j*k.
    --   c = 0: holds if i's bit is 1, otherwise defer to j:  j + k - j*k.
    stepGE j c k
      | c == one = ($ j) * ($ k)
      | otherwise = ($ j) + ($ k) - ($ j) * ($ k)

-- | Desugars range constraints into polynomial constraints.
--
-- The lookups of the circuit are split into range lookups ('isRange') and
-- all other lookups. For every range lookup, every range @(lo, hi)@ in its
-- key is applied, via 'desugarRange', to the first variable of each
-- registered variable list. The resulting circuit keeps only the non-range
-- lookups and has the original output.
--
-- NOTE(review): an empty variable list in a range lookup makes 'head' fail.
-- Variables after the first in a list are not range-checked.
desugarRanges ::
  (Arithmetic a, Binary a, Binary (Rep p), Binary (Rep i), Ord (Rep i)) =>
  ArithmeticCircuit a p i o -> ArithmeticCircuit a p i o
desugarRanges c = desugared { acLookup = otherLookups, acOutput = acOutput c }
  where
    (rangeLookups, otherLookups) =
      M.partitionWithKey (\k _ -> isRange k) (acLookup c)
    -- Re-key the range lookups by their set of ranges.
    rangeSets = M.mapKeys asRanges rangeLookups
    asRanges k
      | isRange k = fromRange k
      | otherwise = error "There should only be a range-lookups here"
    -- One (variable, range) pair per range check to emit.
    checks =
      [ (head (map toVar vars), range)
      | (ranges, varLists) <- M.toList rangeSets
      , vars <- S.toList varLists
      , range <- S.toList ranges
      ]
    -- Run the rewrites on a copy of the circuit with its output set aside.
    desugared =
      execState (traverse (uncurry desugarRange) checks) (c { acOutput = U1 })

-- | Payload of an input to arithmetic circuit.
-- To be used as an argument to 'compileWith'.
--
-- Every position of the payload @p@ is filled with a 'WExVar' placeholder.
-- Every position of the input @i@ is filled with a 'WSysVar' ('InVar' ...)
-- placeholder. Each placeholder is lifted with 'pure', and the two
-- structures are combined by @f@.
inputPayload ::
  (Representable p, Representable i) =>
  (forall x. p x -> i x -> o x) -> o (WitnessF a (WitVar p i))
inputPayload f =
  f (tabulate (\ix -> pure (WExVar ix)))
    (tabulate (\ix -> pure (WSysVar (InVar ix))))

-- | Turns a circuit into a closed check on a guessed output.
--
-- The new circuit takes the original input together with a candidate output
-- (@i :*: o@). The original circuit runs on the @i@ part, and the @o@ part is
-- passed through unchanged. Each computed output variable is then constrained
-- to equal the corresponding guessed one. The circuit has no output of its
-- own ('U1').
guessOutput ::
  (Arithmetic a, Binary a, Binary (Rep p), Binary (Rep i), Binary (Rep o)) =>
  (Ord (Rep i), Ord (Rep o), NFData (Rep i), NFData (Rep o)) =>
  (Representable i, Representable o, Foldable o) =>
  ArithmeticCircuit a p i o -> ArithmeticCircuit a p (i :*: o) U1
guessOutput c =
  fromCircuit2F (hlmap fstP c) (hmap sndP idCircuit) $ \computed guessed -> do
    for_ (mzipRep computed guessed) $ \(out, guess) ->
      constraint (\x -> x out - x guess)
    return U1

----------------------------------- Information -----------------------------------

-- | Calculates the number of constraints in the system, i.e. the number of
-- entries in 'acSystem'.
acSizeN :: ArithmeticCircuit a p i o -> Natural
acSizeN ac = length (acSystem ac)

-- | Calculates the number of variables in the system, counted as the number
-- of entries in 'acWitness'.
acSizeM :: ArithmeticCircuit a p i o -> Natural
acSizeM ac = length (acWitness ac)

-- | Calculates the number of lookups in the system. This is the total number
-- of variable lists registered in 'acLookup', summed over every lookup type.
--
-- NOTE(review): non-range lookup types are counted as well, although the
-- name suggests range lookups only — confirm which is intended.
acSizeR :: ArithmeticCircuit a p i o -> Natural
acSizeR ac = sum [ length entries | entries <- M.elems (acLookup ac) ]

-- | Evaluates the outputs of a circuit that has neither payload nor input.
-- This is a synonym for 'exec'.
acValue ::
  (Arithmetic a, Binary a, Functor o) => ArithmeticCircuit a U1 U1 o -> o a
acValue ac = exec ac

-- | Prints the constraint system, the witness, and the output.
--
-- Each section is a label followed by a pretty-printed value, in this order:
-- system size, variable size, the constraints ("Matrices"), the witness
-- generated for the empty payload and input, the output variables, and their
-- values.
--
-- TODO: Move this elsewhere (?)
-- TODO: Check that all arguments have been applied.
acPrint ::
  (Arithmetic a, Binary a, Show a) =>
  (Show (o (Var a U1)), Show (o a), Functor o) =>
  ArithmeticCircuit a U1 U1 o -> IO ()
acPrint ac = do
    section "System size: " (acSizeN ac)
    section "Variable size: " (acSizeM ac)
    section "Matrices: " (elems (acSystem ac))
    section "Witness: " (witnessGenerator ac U1 U1)
    section "Output: " (acOutput ac)
    section "Value: " (acValue ac)
  where
    -- Prints a label, then the pretty-printed value.
    section :: Show x => String -> x -> IO ()
    section label value = putStr label >> pPrint value

---------------------------------- Testing -------------------------------------

-- | Property: the generated witness does not depend on the circuit input.
--
-- For an arbitrary payload @p@ and arbitrary inputs @x@ and @y@, the witnesses
-- generated for @(p, x)@ and @(p, y)@ must be equal.
isConstantInput ::
  ( Arithmetic a, Binary a, Show a, Representable p, Representable i
  , Show (p a), Show (i a), Arbitrary (p a), Arbitrary (i a)
  ) => ArithmeticCircuit a p i o -> Property
isConstantInput c = property sameWitness
  where
    sameWitness x y p = witnessGenerator c p x === witnessGenerator c p y

-- | Property, checked once: for a circuit with neither payload nor input,
-- every constraint polynomial in 'acSystem' evaluates to zero under the
-- witness generated by 'witnessGenerator'.
checkClosedCircuit
    :: forall a o
     . Arithmetic a
    => Binary a
    => Show a
    => ArithmeticCircuit a U1 U1 o
    -> Property
checkClosedCircuit c = withMaxSuccess 1 (conjoin (map holds (elems (acSystem c))))
    where
        witness = witnessGenerator c U1 U1
        holds poly = evalPolynomial evalMonomial lookupVar poly === zero
        -- There are no input variables ('Rep U1' is uninhabited), so only
        -- new variables need a value; they come from the witness.
        lookupVar sysVar = case sysVar of
            InVar v  -> absurd v
            NewVar v -> witness ! v

-- | Property: every constraint polynomial in 'acSystem' evaluates to zero.
--
-- For each constraint, a random input and then a random payload are drawn.
-- The witness is generated for them, and the polynomial is evaluated with
-- input variables read from the input and new variables read from the
-- witness.
checkCircuit
    :: Arbitrary (p a)
    => Arbitrary (i a)
    => Arithmetic a
    => Binary a
    => Show a
    => Representable p
    => Representable i
    => ArithmeticCircuit a p i o
    -> Property
checkCircuit c = conjoin (map (property . randomCheck) (elems (acSystem c)))
    where
        randomCheck poly = do
            ins <- arbitrary
            pls <- arbitrary
            let witness = witnessGenerator c pls ins
                valueOf (InVar v)  = index ins v
                valueOf (NewVar v) = witness ! v
            pure (evalPolynomial evalMonomial valueOf poly === zero)