{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE UndecidableInstances #-}
module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Map (
mapVarArithmeticCircuit,
) where
import Data.Bifunctor (bimap)
import Data.Functor.Rep (Representable (..))
import Data.Map hiding (drop, foldl, foldr, fromList, map, null,
splitAt, take, toList)
import qualified Data.Map as Map
import qualified Data.Set as Set
import GHC.IsList (IsList (..))
import Numeric.Natural (Natural)
import Prelude hiding (Num (..), drop, length, product, splitAt,
sum, take, (!!), (^))
import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Algebra.Polynomials.Multivariate
import ZkFold.Base.Data.ByteString (toByteString)
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal
-- | Renames every equality variable (@'NewVar' ('EqVar' _)@) of a circuit to
-- a sequential name.
--
-- The distinct 'EqVar' names reported by 'getAllVars' are sorted, and the
-- @n@-th of them (counting from 0) is renamed to
-- @toByteString \@VarField (fromConstant n)@.
--
-- The following are left untouched:
--
-- * input variables ('InVar')
-- * fold variables ('FoldVar' / 'WFoldVar')
-- * external witness variables ('WExVar')
-- * 'acLookupFunction'
--
-- How each field is rewritten:
--
-- * Constraints in 'acSystem' are re-keyed with the same name sequence, in
--   ascending order of their original keys.
-- * Witnesses are re-keyed so that each new name keeps the witness of the
--   variable it replaced.
--
-- NOTE(review): renaming uses the partial '(!)'. It therefore assumes that
-- 'getAllVars' reports every 'EqVar' referenced by constraints, lookups,
-- folds, witnesses and outputs. Confirm this against 'getAllVars'.
mapVarArithmeticCircuit ::
  (Field a, Eq a, Functor o, Ord (Rep i), Representable i, Foldable i) =>
  ArithmeticCircuit a p i o -> ArithmeticCircuit a p i o
mapVarArithmeticCircuit ac =
  let -- Deduplicated and sorted, so the 'Map.fromDistinctAscList'
      -- precondition below holds by construction. It no longer depends on
      -- the order in which 'getAllVars' happens to return variables.
      vars = Set.toAscList $ Set.fromList [v | NewVar (EqVar v) <- getAllVars ac]
      -- New names for indices 0, 1, 2, ..., encoded as 'VarField' elements.
      asc = [toByteString @VarField (fromConstant @Natural x) | x <- [0 ..]]
      -- Maps an old name to its new name.
      forward = Map.fromDistinctAscList $ zip vars asc
      -- Maps a new name back to its old name. Nothing guarantees that the
      -- byte encodings in @asc@ sort in ascending order. Build with
      -- 'Map.fromList' instead of relying on the unchecked
      -- 'Map.fromAscList' precondition.
      backward = Map.fromList $ zip asc vars
      varF (InVar v) = InVar v
      varF (NewVar (EqVar v)) = NewVar (EqVar (forward ! v))
      varF (NewVar (FoldVar fldId fldV)) = NewVar (FoldVar fldId fldV)
      oVarF (LinVar k v b) = LinVar k (varF v) b
      oVarF (ConstVar c) = ConstVar c
      witF (WSysVar v) = WSysVar (varF v)
      witF (WExVar v) = WExVar v
      witF (WFoldVar fldId fldV) = WFoldVar fldId fldV
   in ArithmeticCircuit
        { acLookup = Set.map (map varF) <$> acLookup ac
        , acLookupFunction = acLookupFunction ac
        , acSystem =
            fromList $
              zip asc $
                evalPolynomial evalMonomial (var . varF) <$> elems (acSystem ac)
        , -- For each (new, old) pair in @backward@, pick the witness stored
          -- under @old@ and store it under @new@.
          acWitness = (fmap witF <$> acWitness ac) `Map.compose` backward
        , acFold = bimap oVarF (fmap witF) <$> acFold ac
        , acOutput = oVarF <$> acOutput ac
        }