{-# LANGUAGE AllowAmbiguousTypes  #-}
{-# LANGUAGE UndecidableInstances #-}

module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Map (
        mapVarArithmeticCircuit,
    ) where

import           Data.Bifunctor                                      (bimap)
import           Data.Functor.Rep                                    (Representable (..))
import           Data.Map                                            hiding (drop, foldl, foldr, fromList, map, null,
                                                                      splitAt, take, toList)
import qualified Data.Map                                            as Map
import qualified Data.Set                                            as Set
import           GHC.IsList                                          (IsList (..))
import           Numeric.Natural                                     (Natural)
import           Prelude                                             hiding (Num (..), drop, length, product, splitAt,
                                                                      sum, take, (!!), (^))

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Polynomials.Multivariate
import           ZkFold.Base.Data.ByteString                         (toByteString)
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal

-- This module contains functions for mapping variables in arithmetic circuits.

-- | Renames the 'EqVar' variables of a circuit to a dense sequence of
-- identifiers.
--
-- Every 'EqVar' identifier reported by 'getAllVars' is replaced by the
-- 'toByteString' encoding of its position (0, 1, 2, ...) as a 'VarField'
-- element. Every constraint in 'acSystem' is likewise re-keyed by its
-- position in the system, and 'acWitness' is re-keyed from old to fresh
-- identifiers.
--
-- Left unchanged:
--
-- * input variables ('InVar');
-- * external witness variables ('WExVar');
-- * fold variables ('FoldVar' / 'WFoldVar');
-- * lookup functions ('acLookupFunction').
--
-- The renaming maps are built with 'Map.fromList' rather than
-- 'Map.fromAscList'. Nothing here guarantees that the old identifiers or the
-- encoded fresh identifiers are in ascending @ByteString@ order; for example,
-- a little-endian encoding of the indices would not sort numerically. When
-- that precondition is violated, 'Map.fromAscList' silently builds a
-- malformed 'Map'.
mapVarArithmeticCircuit ::
  (Field a, Eq a, Functor o, Ord (Rep i), Representable i, Foldable i) =>
  ArithmeticCircuit a p i o -> ArithmeticCircuit a p i o
mapVarArithmeticCircuit ac =
    let -- Identifiers of every 'EqVar' variable, in the order 'getAllVars' yields them.
        vars = [v | NewVar (EqVar v) <- getAllVars ac]
        -- Fresh identifiers: the encodings of 0, 1, 2, ... as 'VarField' elements.
        -- Their byte-wise order is not established here, so they are never
        -- passed to 'Map.fromAscList'.
        freshIds = [ toByteString @VarField (fromConstant @Natural x) | x <- [0..] ]
        -- Old identifier -> fresh identifier.
        forward = Map.fromList $ zip vars freshIds
        -- Fresh identifier -> old identifier.
        backward = Map.fromList $ zip freshIds vars

        -- Renames 'EqVar's; input and fold variables pass through unchanged.
        varF (InVar v)                     = InVar v
        -- Partial: 'Map.!' fails if an 'EqVar' is not among those from 'getAllVars'.
        varF (NewVar (EqVar v))            = NewVar (EqVar (forward Map.! v))
        -- TODO: compress fold ids, too
        varF (NewVar (FoldVar fldId fldV)) = NewVar (FoldVar fldId fldV)

        -- Applies 'varF' to the variable of a linear output term.
        oVarF (LinVar k v b) = LinVar k (varF v) b
        oVarF (ConstVar c)   = ConstVar c

        -- Applies 'varF' to system variables referenced from witnesses.
        witF (WSysVar v)    = WSysVar (varF v)
        witF (WExVar v)     = WExVar v
        -- TODO: compress fold ids, too
        witF (WFoldVar i v) = WFoldVar i v
     in ArithmeticCircuit
          { acLookup         = Set.map (map varF) <$> acLookup ac
          , acLookupFunction = acLookupFunction ac
          -- Constraints are re-keyed by their position in the system.
          , acSystem         = Map.fromList
                             $ zip freshIds
                             $ evalPolynomial evalMonomial (var . varF) <$> Map.elems (acSystem ac)
          -- Witness entries are re-keyed from old to fresh identifiers.
          , acWitness        = (fmap witF <$> acWitness ac) `Map.compose` backward
          , acFold           = bimap oVarF (fmap witF) <$> acFold ac
          , acOutput         = oVarF <$> acOutput ac
          }
          }