{-# LANGUAGE DerivingStrategies   #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}

module ZkFold.Symbolic.Data.Eq (
    Eq(..),
    elem,
    GEq (..)
) where

import           Data.Bool                        (bool)
import           Data.Foldable                    (Foldable)
import           Data.Functor.Rep                 (Representable, mzipRep, mzipWithRep)
import           Data.Traversable                 (Traversable, for)
import qualified Data.Vector                      as V
import qualified GHC.Generics                     as G
import           Prelude                          (return, ($))
import qualified Prelude                          as Haskell

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Data.Package
import           ZkFold.Base.Data.Vector
import           ZkFold.Symbolic.Class
import           ZkFold.Symbolic.Data.Bool        (Bool (Bool), BoolType (..), all, any)
import           ZkFold.Symbolic.Data.Combinators (runInvert)
import           ZkFold.Symbolic.MonadCircuit

-- | Equality whose result type @b@ is a class parameter, so an instance can
-- answer with a plain 'Haskell.Bool' or with a symbolic boolean such as
-- @'Bool' c@ (instances for both appear in this module).
--
-- Both methods have defaults that go through the 'G.Generic' representation
-- of @a@ and delegate to 'geq' / 'gneq'. Note that the default '/=' is its
-- own generic traversal, not the negation of '=='.
class Eq b a where
    infix 4 ==
    -- | Equality test.
    (==) :: a -> a -> b
    default (==) :: (G.Generic a, GEq b (G.Rep a)) => a -> a -> b
    (==) lhs rhs = G.from lhs `geq` G.from rhs

    infix 4 /=
    -- | Inequality test.
    (/=) :: a -> a -> b
    default (/=) :: (G.Generic a, GEq b (G.Rep a)) => a -> a -> b
    (/=) lhs rhs = G.from lhs `gneq` G.from rhs

-- | @elem needle haystack@ is the 'any'-combination of comparing every element
-- of @haystack@ against @needle@ with '=='.
elem :: (BoolType b, Eq b a, Foldable t) => a -> t a -> b
elem needle haystack = any matches haystack
  where
    -- The element is the left operand and the needle the right one.
    matches candidate = candidate == needle

-- | Plain Haskell equality: any type with a Prelude 'Haskell.Eq' instance can
-- be compared with a 'Haskell.Bool' result.
--
-- NOTE(review): this head (@Eq Haskell.Bool a@) unifies with the
-- @Eq b (Vector n x)@ and tuple instances in this module when
-- @b ~ Haskell.Bool@. Neither head is more specific, so such a use site
-- would be rejected as overlapping. Confirm this is intended.
instance Haskell.Eq a => Eq Haskell.Bool a where
    x == y = x Haskell.== y
    x /= y = x Haskell./= y

-- | Componentwise comparison of two symbolic values of the same shape @f@.
--
-- Each method builds one 'symbolic2F' node from two parts:
--
-- * a direct evaluation, which yields 'one' / 'zero' per component from the
--   Haskell comparison of the field elements;
-- * a circuit, which allocates one new variable per component holding the
--   difference of the two inputs and feeds those differences to 'runInvert'.
--
-- The per-component flags are then unpacked and combined with 'all' (for
-- '==') or 'any' (for '/=').
instance (Symbolic c, Haskell.Eq (BaseField c), Representable f, Traversable f)
  => Eq (Bool c) (c f) where
    -- Equal iff every component's flag is set.
    x == y = all Bool (unpacked sameFlags)
      where
        sameFlags = symbolic2F x y
            (mzipWithRep (\a b -> bool zero one (a Haskell.== b)))
            (\xs ys -> do
                diffs <- for (mzipRep xs ys) $ \(a, b) ->
                    newAssigned (\w -> w a - w b)
                -- Only the first component of runInvert's result is used.
                -- Judging by the evaluation above, it is presumably 1 exactly
                -- where the difference is 0 (see 'runInvert').
                (zeroFlags, _) <- runInvert diffs
                return zeroFlags
            )

    -- Different iff at least one component's flag is set.
    x /= y = any Bool (unpacked differFlags)
      where
        differFlags = symbolic2F x y
            (mzipWithRep (\a b -> bool zero one (a Haskell./= b)))
            (\xs ys -> do
                diffs <- for (mzipRep xs ys) $ \(a, b) ->
                    newAssigned (\w -> w a - w b)
                (zeroFlags, _) <- runInvert diffs
                -- Flip each zero flag: one new variable holding (1 - flag).
                for zeroFlags $ \z -> newAssigned (\w -> one - w z)
            )

-- | Elementwise comparison of fixed-length vectors.
--
-- The vectors are converted with 'toV', zipped with the element comparison,
-- and the results are folded from the left:
--
-- * '==' combines the element results with '&&', starting from 'true';
-- * '/=' combines them with '||', starting from 'false'.
--
-- 'V.foldl'' is used instead of the lazy 'V.foldl' so that no chain of @n@
-- thunks is built before evaluation. The combination order (left-nested) is
-- unchanged, so the result, and any circuit it builds, is the same.
instance (BoolType b, Eq b x) => Eq b (Vector n x) where
    u == v = V.foldl' (&&) true (V.zipWith (==) (toV u) (toV v))
    u /= v = V.foldl' (||) false (V.zipWith (/=) (toV u) (toV v))

-- | Symbolic booleans reuse the instance of their underlying @c Par1@
-- representation via newtype deriving.
deriving newtype instance Symbolic c => Eq (Bool c) (Bool c)

-- Tuples rely on the class's default methods, i.e. the 'G.Generic'-based
-- 'geq' / 'gneq'. '==' combines the component '==' results with '&&';
-- '/=' combines the component '/=' results with '||'.
instance (BoolType b, Eq b x0, Eq b x1) => Eq b (x0,x1)
instance (BoolType b, Eq b x0, Eq b x1, Eq b x2) => Eq b (x0,x1,x2)
instance (BoolType b, Eq b x0, Eq b x1, Eq b x2, Eq b x3) => Eq b (x0,x1,x2,x3)

-- | Generic worker behind the default 'Eq' methods. It compares two
-- 'G.Generic' representations and produces a result of type @b@.
--
-- The instances visible in this module cover products ('G.:*:'),
-- metadata wrappers ('G.M1') and fields ('G.Rec0'). No sum or nullary
-- constructor instance is visible here (NOTE(review): confirm none exists
-- elsewhere before relying on the default for such types).
class GEq b u where
    geq, gneq :: u x -> u x -> b

-- | Products: compare both halves. Equality needs both halves equal ('&&');
-- inequality needs either half to differ ('||').
instance (BoolType b, GEq b u, GEq b v) => GEq b (u G.:*: v) where
    geq (xl G.:*: xr) (yl G.:*: yr) =
        let leftSame  = geq xl yl
            rightSame = geq xr yr
        in  leftSame && rightSame
    gneq (xl G.:*: xr) (yl G.:*: yr) =
        let leftDiffers  = gneq xl yl
            rightDiffers = gneq xr yr
        in  leftDiffers || rightDiffers

-- | Metadata wrappers (datatype, constructor, selector) carry no data of their
-- own; unwrap them and compare the contents.
instance GEq b v => GEq b (G.M1 i c v) where
    geq lhs rhs = geq (G.unM1 lhs) (G.unM1 rhs)
    gneq lhs rhs = gneq (G.unM1 lhs) (G.unM1 rhs)

-- | Fields: compare the wrapped values using the field type's own 'Eq'
-- instance.
instance Eq b x => GEq b (G.Rec0 x) where
    geq lhs rhs = G.unK1 lhs == G.unK1 rhs
    gneq lhs rhs = G.unK1 lhs /= G.unK1 rhs