{-# LANGUAGE DerivingStrategies   #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}

module ZkFold.Symbolic.Data.Eq (
    Eq(..),
    elem,
    SymbolicEq,
    GEq (..)
) where

import           Data.Bool                        (bool)
import           Data.Foldable                    (Foldable)
import           Data.Functor.Rep                 (mzipRep, mzipWithRep)
import           Data.Traversable                 (for)
import qualified Data.Vector                      as V
import qualified GHC.Generics                     as G
import           Prelude                          (return, type (~), ($))
import qualified Prelude                          as Haskell

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Basic.Field
import           ZkFold.Base.Algebra.Basic.Number
import           ZkFold.Base.Data.Package
import           ZkFold.Base.Data.Vector
import           ZkFold.Symbolic.Class
import           ZkFold.Symbolic.Data.Bool        (Bool (Bool), BoolType (..), all, any)
import           ZkFold.Symbolic.Data.Class
import           ZkFold.Symbolic.Data.Combinators (runInvert)
import           ZkFold.Symbolic.Data.Conditional (Conditional, GConditional)
import           ZkFold.Symbolic.MonadCircuit

-- | Equality with a configurable boolean result type.
--
-- Comparisons return @'BooleanOf' a@ rather than 'Haskell.Bool', so the same
-- class serves plain Haskell values (whose instances in this module use
-- 'Haskell.Bool') and symbolic values (which use the symbolic 'Bool').
-- Both methods have 'G.Generic'-based defaults built on 'GEq'.
class Conditional (BooleanOf a) a => Eq a where

    -- | Boolean type returned by '==' and '/='. Defaults to the boolean type
    -- of the generic representation of @a@.
    type BooleanOf a
    type BooleanOf a = GBooleanOf (G.Rep a)

    infix 4 ==
    -- | Equality test.
    (==) :: a -> a -> BooleanOf a
    default (==)
      :: (G.Generic a, GEq (G.Rep a), BooleanOf a ~ GBooleanOf (G.Rep a))
      => a -> a -> BooleanOf a
    -- Generic default: compare the representations of both arguments.
    x == y = geq (G.from x) (G.from y)

    infix 4 /=
    -- | Inequality test. The default is built from 'gneq' rather than by
    -- negating '=='.
    (/=) :: a -> a -> BooleanOf a
    default (/=)
      :: (G.Generic a, GEq (G.Rep a), BooleanOf a ~ GBooleanOf (G.Rep a))
      => a -> a -> BooleanOf a
    -- Generic default: compare the representations of both arguments.
    x /= y = gneq (G.from x) (G.from y)

-- | @elem x xs@ compares every element of @xs@ against @x@ with '==' and
-- combines the per-element results with 'any'.
elem :: (Eq a, Foldable t) => a -> t a -> BooleanOf a
elem x = any (== x)

-- | Plain Haskell equality on 'Natural', returning 'Haskell.Bool'.
instance Eq Natural where
    type BooleanOf Natural = Haskell.Bool
    (==) = (Haskell.==)
    (/=) = (Haskell./=)
-- | Plain Haskell equality on 'Haskell.Bool', returning 'Haskell.Bool'.
instance Eq Haskell.Bool where
    type BooleanOf Haskell.Bool = Haskell.Bool
    (==) = (Haskell.==)
    (/=) = (Haskell./=)
-- | Plain Haskell equality on 'Haskell.String', returning 'Haskell.Bool'.
instance Eq Haskell.String where
    type BooleanOf Haskell.String = Haskell.Bool
    (==) = (Haskell.==)
    (/=) = (Haskell./=)
-- | Plain Haskell equality on prime-field elements 'Zp', returning
-- 'Haskell.Bool'.
instance KnownNat n => Eq (Zp n) where
    type BooleanOf (Zp n) = Haskell.Bool
    (==) = (Haskell.==)
    (/=) = (Haskell./=)

-- | Element-wise equality for symbolic values laid out over @f@.
--
-- Each method passes two implementations to 'symbolic2F': a pure function
-- over field values, and a circuit that allocates the per-position
-- difference @x_i - y_i@ and feeds it to 'runInvert'. The resulting
-- per-position flags are unpacked into symbolic booleans and combined with
-- 'all' (for '==') or 'any' (for '/=').
instance (Symbolic c, LayoutFunctor f)
  => Eq (c f) where
    type BooleanOf (c f) = Bool c

    x == y =
        let
            -- One field element per position: 'one' where the operands
            -- agree, 'zero' otherwise.
            result = symbolic2F x y
                -- Pure function over the field values ('BaseField' c).
                (mzipWithRep (\i j -> bool zero one (i Haskell.== j)))
                -- Circuit construction.
                (\x' y' -> do
                    -- New variable holding x_i - y_i for every position.
                    difference <- for (mzipRep x' y') $ \(i, j) ->
                        newAssigned (\w -> w i - w j)
                    -- NOTE(review): the first component of 'runInvert' is
                    -- used as the per-position "difference is zero" flag,
                    -- matching the pure function above; confirm against
                    -- ZkFold.Symbolic.Data.Combinators.
                    (isZeros, _) <- runInvert difference
                    return isZeros
                )
        in
            all Bool (unpacked result)

    x /= y =
        let
            -- One field element per position: 'one' where the operands
            -- differ, 'zero' where they agree.
            result = symbolic2F x y
                -- Pure function over the field values ('BaseField' c).
                (mzipWithRep (\i j -> bool zero one (i Haskell./= j)))
                -- Circuit construction.
                (\x' y' -> do
                    -- New variable holding x_i - y_i for every position.
                    difference <- for (mzipRep x' y') $ \(i, j) ->
                        newAssigned (\w -> w i - w j)
                    -- First component of 'runInvert' taken as the
                    -- per-position "difference is zero" flag, as in '=='.
                    (isZeros, _) <- runInvert difference
                    -- Flip each flag: 1 - isZero.
                    for isZeros $ \isZ ->
                      newAssigned (\w -> one - w isZ)
                )
        in
            any Bool (unpacked result)

-- | Constraint synonym for a symbolic output type @x@ whose 'Eq' instance
-- yields the symbolic 'Bool' of @x@'s own context.
type SymbolicEq x =
  (SymbolicOutput x, Eq x, BooleanOf x ~ Bool (Context x))

-- | Element-wise comparison of fixed-length vectors.
--
-- '==' folds the element-wise '==' results with '&&' starting from 'true';
-- '/=' folds the element-wise '/=' results with '||' starting from 'false'.
instance (KnownNat n, Eq x) => Eq (Vector n x) where
    type BooleanOf (Vector n x) = BooleanOf x
    -- Strict left folds: the accumulator is forced at each step instead of
    -- building a chain of n thunks. The fold order, and therefore the shape
    -- of the resulting boolean expression, matches a lazy left fold.
    u == v = V.foldl' (&&) true (V.zipWith (==) (toV u) (toV v))
    u /= v = V.foldl' (||) false (V.zipWith (/=) (toV u) (toV v))

-- | Symbolic booleans reuse the 'Eq' instance of the @c Par1@ value that the
-- 'Bool' newtype wraps (generalized newtype deriving, including the
-- associated 'BooleanOf').
deriving newtype instance Symbolic c => Eq (Bool c)

-- | Pairs compare component-wise through the 'G.Generic' defaults of 'Eq';
-- both components must share the same 'BooleanOf' type.
instance (Eq x0, Eq x1, BooleanOf x0 ~ BooleanOf x1) => Eq (x0,x1)
-- | Triples: generic defaults; all components share one 'BooleanOf' type.
instance
  ( Eq x0, Eq x1, Eq x2
  , BooleanOf x0 ~ BooleanOf x1
  , BooleanOf x1 ~ BooleanOf x2
  ) => Eq (x0,x1,x2)
-- | Quadruples: generic defaults; all components share one 'BooleanOf' type.
instance
  ( Eq x0, Eq x1, Eq x2, Eq x3
  , BooleanOf x0 ~ BooleanOf x1
  , BooleanOf x1 ~ BooleanOf x2
  , BooleanOf x2 ~ BooleanOf x3
  ) => Eq (x0,x1,x2,x3)

-- | Field extensions compare through the 'G.Generic' defaults of 'Eq'.
instance Eq field => Eq (Ext2 field i)
-- | Field extensions compare through the 'G.Generic' defaults of 'Eq'.
instance Eq field => Eq (Ext3 field i)

-- | Generic counterpart of 'Eq' over 'G.Rep' representations; used by the
-- default methods of 'Eq'.
class GConditional (GBooleanOf u) u => GEq u where
    -- | Boolean type produced by comparing two representations.
    type GBooleanOf u
    -- | Generic equality.
    geq :: u x -> u x -> GBooleanOf u
    -- | Generic inequality.
    gneq :: u x -> u x -> GBooleanOf u

-- | Products compare both halves. 'geq' joins the results with '&&' and
-- 'gneq' with '||'; both halves must share one boolean type.
instance (GBooleanOf u ~ GBooleanOf v, GEq u, GEq v) => GEq (u G.:*: v) where
    type GBooleanOf (u G.:*: v) = GBooleanOf u
    geq  (x0 G.:*: x1) (y0 G.:*: y1) = geq x0 y0 && geq x1 y1
    gneq (x0 G.:*: x1) (y0 G.:*: y1) = gneq x0 y0 || gneq x1 y1

-- | Metadata wrappers ('G.M1') are transparent: unwrap and compare the
-- wrapped representation.
instance GEq v => GEq (G.M1 i c v) where
    type GBooleanOf (G.M1 i c v) = GBooleanOf v
    geq  (G.M1 x) (G.M1 y) = geq x y
    gneq (G.M1 x) (G.M1 y) = gneq x y

-- | Individual fields ('G.Rec0') defer to the field type's own 'Eq'
-- instance, inheriting its 'BooleanOf'.
instance Eq x => GEq (G.Rec0 x) where
    type GBooleanOf (G.Rec0 x) = BooleanOf x
    geq  (G.K1 x) (G.K1 y) = x == y
    gneq (G.K1 x) (G.K1 y) = x /= y