module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Optimization where

import           Data.Binary                                             (Binary)
import           Data.Bool                                               (bool)
import           Data.ByteString                                         (ByteString)
import           Data.Functor                                            ((<&>))
import           Data.Functor.Rep                                        (Representable (..))
import           Data.Map                                                hiding (drop, foldl, foldr, map, null, splitAt,
                                                                          take)
import qualified Data.Map.Internal                                       as M
import qualified Data.Map.Monoidal                                       as MM
import qualified Data.Set                                                as S
import           Prelude                                                 hiding (Num (..), drop, length, product,
                                                                          splitAt, sum, take, (!!), (^))

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Basic.Number
import           ZkFold.Base.Algebra.Polynomials.Multivariate            (evalMonomial)
import           ZkFold.Base.Algebra.Polynomials.Multivariate.Monomial   (Mono (..), oneM)
import           ZkFold.Base.Algebra.Polynomials.Multivariate.Polynomial (Poly (..), evalPolynomial, var)
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Instance     ()
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Witness      (WitnessF (..))

--------------------------------- High-level functions --------------------------------

-- | Simplifies an arithmetic circuit by eliminating variables whose value is
-- forced to a single constant by one of its constraints.
--
-- Candidate variables are found by 'varsToReplace': a constraint
-- @k * x = 0@ fixes @x = 0@, and a constraint @c + k * x = 0@ fixes
-- @x = negate c // k@. The fixed values are substituted into the system until
-- no more variables can be fixed. Then every part of the circuit is rewritten:
--
-- * __system__: the reduced system is kept. For every eliminated input
--   variable ('InVar') fixed to @v@, the constraint @x - v@ is added back so
--   that the input stays pinned to @v@.
-- * __ranges__: eliminated variables are removed from the range sets, and
--   sets that become empty are dropped. Calls 'error' if a fixed value is
--   greater than the bound of a range set it belongs to.
-- * __witness__: generators of eliminated 'NewVar's are removed. Any
--   reference to an eliminated variable inside the remaining generators is
--   replaced by its constant.
-- * __output__: outputs that refer to an eliminated variable become
--   'ConstVar's.
optimize :: forall a p i o.
  (Arithmetic a, Ord (Rep i), Functor o, Binary (Rep i), Binary a, Binary (Rep p)) =>
  ArithmeticCircuit a p i o -> ArithmeticCircuit a p i o
optimize (ArithmeticCircuit circSystem circRanges circWitness circOutput) =
  ArithmeticCircuit
    { acSystem  = addInVarConstraints reducedSystem
    , acRange   = optRanges fixedVars circRanges
    , acWitness = (\gen -> gen >>= optWitVar fixedVars) <$> survivingWitness
    , acOutput  = circOutput <&> optOutVar
    }
  where
    -- Iterate the substitution to a fixed point, starting with no fixed variables.
    (reducedSystem, fixedVars) = varsToReplace (circSystem, M.empty)

    -- Witness generators of 'NewVar's that were eliminated are no longer needed.
    survivingWitness = M.filterWithKey (\wKey _ -> NewVar wKey `notMember` fixedVars) circWitness

    -- An output that is an eliminated variable becomes the constant it was fixed to.
    optOutVar :: Var a i -> Var a i
    optOutVar (SysVar sV) = maybe (SysVar sV) ConstVar (M.lookup sV fixedVars)
    optOutVar other       = other

    isInVar :: SysVar i -> Bool
    isInVar InVar{} = True
    isInVar _       = False

    -- For every eliminated input variable fixed to @v@, add the constraint
    -- @x - v@. It is keyed by the id of the matching witness @x - v@, so the
    -- input stays pinned to @v@.
    addInVarConstraints :: Map ByteString (Poly a (SysVar i) Natural) -> Map ByteString (Poly a (SysVar i) Natural)
    addInVarConstraints sys = sys <> fromList (pinInput <$> fixedInputs)
      where
        fixedInputs = assocs (filterWithKey (\sv _ -> isInVar sv) fixedVars)
        pinInput (inVar, v) =
          ( toVar @a @p @i (pure (WSysVar inVar) - fromConstant v)
          , var inVar - fromConstant v
          )

    -- Checks that no fixed value exceeds the bound of a range set containing
    -- it. Then drops the fixed variables from every set, discarding sets that
    -- become empty.
    optRanges :: Map (SysVar i) a -> MM.MonoidalMap a (S.Set (SysVar i)) -> MM.MonoidalMap a (S.Set (SysVar i))
    optRanges m = MM.mapMaybeWithKey trimRange
      where
        trimRange bound rangeVars
          | all (<= bound) (restrictKeys m rangeVars) =
              let remaining = rangeVars `S.difference` keysSet m
              in if null remaining then Nothing else Just remaining
          | otherwise = error "range constraint less then value"

    -- Replaces a witness reference to a fixed variable by its constant.
    -- Any other reference is left as it is.
    optWitVar :: Map (SysVar i) a -> WitVar p i -> WitnessF a (WitVar p i)
    optWitVar m (WSysVar sv) = maybe (pure (WSysVar sv)) fromConstant (M.lookup sv m)
    optWitVar _ other        = pure other

-- | Repeatedly fixes variables that some constraint forces to a single
-- constant, and substitutes them into the system until nothing new is found.
--
-- The argument pairs the current system with the variables fixed so far
-- ('optimize' starts from 'M.empty'). The result pairs the reduced system with
-- every variable fixed along the way. Each round does three things:
--
-- 1. A constraint @k * x = 0@ fixes @x = 0@, and a constraint
--    @c + k * x = 0@ fixes @x = negate c // k@ (see @toConstVar@).
-- 2. The new values are substituted into all constraints, and constraints
--    that became 'zero' are dropped.
-- 3. The new values are merged into the accumulator.
--
-- Calls 'error' with @"unsatisfiable constraint"@ if, after substitution,
-- some constraint is a single constant term with a nonzero coefficient.
--
-- NOTE(review): if two constraints fix the same variable, 'M.fromList' keeps
-- the value from the constraint with the larger key. Presumably a conflicting
-- second value then surfaces as the error above; confirm.
varsToReplace :: (Arithmetic a, Ord (Rep i)) => (Map ByteString (Constraint a i) , Map (SysVar i) a) -> (Map ByteString (Constraint a i) , Map (SysVar i) a)
varsToReplace (s, l)
  -- Fixed point: no constraint pins a variable any more.
  | M.null newVars = (s, l)
  | otherwise      = varsToReplace (M.filter (/= zero) $ optimizeSystems newVars s, M.union newVars l)
  where
    -- Variables pinned to a constant by some constraint of the current system.
    newVars = M.fromList . M.elems $ mapMaybe toConstVar s

    -- Substitutes the fixed values into every constraint and checks that no
    -- constraint collapsed to a nonzero constant.
    optimizeSystems :: (Arithmetic a, Ord (Rep i)) => Map (SysVar i) a -> Map ByteString (Constraint a i) -> Map ByteString (Constraint a i)
    optimizeSystems m cs = bool (error "unsatisfiable constraint") ns (all checkZero ns)
      where
        ns = evalPolynomial evalMonomial varF <$> cs
        -- Fixed variables become constants; all others stay symbolic.
        varF p = maybe (var p) fromConstant (M.lookup p m)
        -- A single-term constraint is acceptable unless it is a nonzero
        -- constant. Equivalent to ((c == zero) && oneM mx) || not (oneM mx).
        checkZero (P [(c, mx)]) = not (oneM mx) || c == zero
        checkZero _             = True

    -- Recognises constraints that pin one variable to a constant. The first
    -- form is @k * x@, giving @x = 0@. The second is a constant term plus
    -- @k * x@, with the constant in either position, giving @x = -c / k@.
    -- Only monomials that are exactly @x^1@ are accepted.
    toConstVar :: Arithmetic a => Constraint a i -> Maybe (SysVar i, a)
    toConstVar = \case
      P [(_, M m1)] -> (\x -> (x, zero)) <$> linearVar m1
      P [(c, M m1), (k, M m2)]
        | oneM (M m1) -> (\x -> (x, negate c // k)) <$> linearVar m2
        | oneM (M m2) -> (\x -> (x, negate k // c)) <$> linearVar m1
      _ -> Nothing

    -- The variable of a monomial that is exactly @x^1@, if any.
    linearVar :: Map k Natural -> Maybe k
    linearVar mono = case toList mono of
      [(x, 1)] -> Just x
      _        -> Nothing