{-# LANGUAGE AllowAmbiguousTypes  #-}
{-# LANGUAGE DerivingStrategies   #-}
{-# LANGUAGE UndecidableInstances #-}

{-# OPTIONS_GHC -Wno-orphans     #-}

module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Instance where

import           Control.DeepSeq                                     (NFData)
import           Data.Aeson                                          hiding (Bool)
import           Data.Binary                                         (Binary)
import           Data.Bool                                           (bool)
import           Data.Functor.Rep                                    (Representable (..))
import           Data.Map                                            hiding (drop, foldl, foldl', foldr, map, null,
                                                                      splitAt, take, toList)
import           GHC.Generics                                        (Par1 (..))
import           Prelude                                             (Show, head, mempty, pure, return, show, ($), (++),
                                                                      (.), (<$>), (<))
import qualified Prelude                                             as Haskell
import           Test.QuickCheck                                     (Arbitrary (arbitrary), Gen, elements)

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Basic.Number
import           ZkFold.Base.Data.Vector                             (Vector, unsafeToVector)
import           ZkFold.Prelude                                      (genSubset, length)
import           ZkFold.Symbolic.Class
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Lookup   (LookupType)
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Var
import           ZkFold.Symbolic.Data.FieldElement                   (FieldElement (..))
import           ZkFold.Symbolic.MonadCircuit

------------------------------------- Instances -------------------------------------

-- | Random single-output circuits.
--
-- Starts from the empty circuit whose only output is a randomly chosen input
-- variable, then grows it for 10 steps with 'arbitrary''.
instance
  ( Arithmetic a
  , Arbitrary a
  , Binary a
  , Binary (Rep p)
  , Arbitrary (Rep i)
  , Binary (Rep i)
  , Haskell.Ord (Rep i)
  , NFData (Rep i)
  , Representable i
  , Haskell.Foldable i
  ) => Arbitrary (ArithmeticCircuit a p i Par1) where
    arbitrary = do
        input <- arbitrary
        let seed = mempty { acOutput = Par1 (toVar (InVar input)) }
        fromFieldElement <$> arbitrary' (FieldElement seed) 10

-- | Random circuits with @l@ outputs.
--
-- Generates a random single-output circuit, then uses 'genSubset' to pick @l@
-- of its variables as the new outputs.
instance
  ( Arithmetic a
  , Arbitrary a
  , Binary a
  , Binary (Rep p)
  , Arbitrary (Rep i)
  , Binary (Rep i)
  , Haskell.Ord (Rep i)
  , NFData (Rep i)
  , Representable i
  , Haskell.Foldable i
  , KnownNat l
  ) => Arbitrary (ArithmeticCircuit a p i (Vector l)) where
    arbitrary = do
        base   <- arbitrary @(ArithmeticCircuit a p i Par1)
        picked <- genSubset (value @l) (getAllVars base)
        pure base { acOutput = toVar <$> unsafeToVector picked }

-- | Grow a random circuit by @iter@ steps, starting from @ac@.
--
-- Each step picks two (possibly equal) variables of the current circuit with
-- 'elements' and replaces the circuit by one of the following, chosen at
-- random:
--
-- * the sum of the two variables;
-- * their product;
-- * their difference;
-- * their quotient;
-- * the circuit extended by 'createRangeConstraint' with bound 10, whose
--   output is the first picked variable.
--
-- When no steps remain, a circuit with exactly one variable has that variable
-- squared. Any other circuit is returned unchanged. A circuit without
-- variables no longer crashes on an empty list here.
--
-- NOTE(review): 'elements' fails on an empty list, so @iter > 0@ still
-- assumes the starting circuit has at least one variable.
arbitrary' ::
  forall a p i .
  (Arithmetic a, Binary a, Binary (Rep p), Binary (Rep i), Haskell.Ord (Rep i), NFData (Rep i)) =>
  (Representable i, Haskell.Foldable i) =>
  FieldElement (ArithmeticCircuit a p i) -> Natural ->
  Gen (FieldElement (ArithmeticCircuit a p i))
arbitrary' ac iter = case iter of
    0 -> return $ case vars of
        -- Same as the former "fewer than two variables" guard for the
        -- non-empty case, but total: no 'head' on an empty list.
        [v] -> let f = withOutput v in f * f
        _   -> ac
    _ -> do
        li <- elements vars
        ri <- elements vars
        let l = withOutput li
            r = withOutput ri
            c = FieldElement
                    ((fromFieldElement $ createRangeConstraint ac (fromConstant @Natural 10))
                        { acOutput = pure (toVar li) })
        ac' <- elements [l + r, l * r, l - r, l // r, c]
        arbitrary' ac' (iter -! 1)
  where
    -- The underlying single-output circuit.
    circuit :: ArithmeticCircuit a p i Par1
    circuit = fromFieldElement ac

    -- Every variable currently present in the circuit.
    vars :: [SysVar i]
    vars = getAllVars circuit

    -- The same circuit, with its output redirected to the given variable.
    withOutput :: SysVar i -> FieldElement (ArithmeticCircuit a p i)
    withOutput v = FieldElement (circuit { acOutput = pure (toVar v) })

-- | Adds a range constraint to the circuit behind @x@ without changing its
-- output.
--
-- A fresh variable assigned to zero is allocated and range-constrained by
-- @bound@. The original output variable is passed through untouched.
createRangeConstraint :: Symbolic c => FieldElement c -> BaseField c -> FieldElement c
createRangeConstraint (FieldElement x) bound =
    FieldElement (fromCircuitF x (\(Par1 v) -> Haskell.fmap Par1 (withRange v bound)))
  where
    -- The explicit polymorphic signature matters here: the helper must work
    -- in any 'MonadCircuit' that 'fromCircuitF' instantiates it at.
    withRange :: MonadCircuit var a w m => var -> a -> m var
    withRange input limit = do
        fresh <- newAssigned (Haskell.const zero)
        rangeConstraint fresh limit
        pure input

-- TODO: make it more readable
instance (Show a, Show (o (Var a i)), Show (Var a i), Show (Rep i), Haskell.Ord (Rep i)) => Show (ArithmeticCircuit a p i o) where
    -- Renders the constraint system, the lookup constraints and the outputs.
    -- Witness, fold and lookup-function data are not shown. Each label names
    -- the record field it prints; the lookup field was previously labelled
    -- "acRange", which is not a field of the record.
    show r = Haskell.concat
        [ "ArithmeticCircuit { acSystem = ", show (acSystem r)
        , "\n, acLookup = ", show (acLookup r)
        , "\n, acOutput = ", show (acOutput r)
        , " }"
        ]

-- TODO: add witness generation info to the JSON object
instance (ToJSON a, ToJSON (o (Var a i)), ToJSONKey a, FromJSONKey (Var a i), ToJSON (Rep i), ToJSON (LookupType a), ToJSONKey (LookupType a)) => ToJSON (ArithmeticCircuit a p i o) where
    -- Only the constraint system, the lookup constraints and the outputs are
    -- serialised, under the keys "system", "lookup" and "output".
    toJSON circuit = object
        [ "system" .= acSystem circuit
        , "lookup" .= acLookup circuit
        , "output" .= acOutput circuit
        ]

-- TODO: properly restore the witness generation function
instance (FromJSON a, FromJSON (o (Var a i)), ToJSONKey (Var a i), FromJSONKey a, Haskell.Ord a, Haskell.Ord (Rep i), FromJSON (Rep i)) => FromJSON (ArithmeticCircuit a p i o) where
    -- Reads the "system", "lookup" and "output" keys of a JSON object.
    -- The witness, fold and lookup-function maps are not part of the JSON,
    -- so they are restored as empty maps.
    parseJSON = withObject "ArithmeticCircuit" $ \obj -> do
        system  <- obj .: "system"
        lookups <- obj .: "lookup"
        output  <- obj .: "output"
        pure ArithmeticCircuit
            { acSystem         = system
            , acLookup         = lookups
            , acOutput         = output
            , acWitness        = empty
            , acFold           = empty
            , acLookupFunction = empty
            }