{-# LANGUAGE DerivingStrategies   #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}

module ZkFold.Symbolic.Data.Eq (
    Eq(..),
    elem,
    GEq (..)
) where

import           Data.Bool                        (bool)
import           Data.Foldable                    (Foldable)
import           Data.Functor.Rep                 (Representable, mzipRep, mzipWithRep)
import           Data.Traversable                 (Traversable, for)
import qualified Data.Vector                      as V
import qualified GHC.Generics                     as G
import           Prelude                          (return, ($))
import qualified Prelude                          as Haskell

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Basic.Field
import           ZkFold.Base.Algebra.Basic.Number
import           ZkFold.Base.Data.Package
import           ZkFold.Base.Data.Vector
import           ZkFold.Symbolic.Class
import           ZkFold.Symbolic.Data.Bool        (Bool (Bool), BoolType (..), all, any)
import           ZkFold.Symbolic.Data.Combinators (runInvert)
import           ZkFold.Symbolic.MonadCircuit

-- | Equality over a caller-chosen boolean type @b@.
--
-- Both methods have a generic default: when @a@ has a 'G.Generic' instance
-- and its representation has a 'GEq' instance, the two values are converted
-- with 'G.from' and compared with 'geq' / 'gneq'.
class Eq b a where
    infix 4 ==
    -- | Equality test, producing a @b@.
    (==) :: a -> a -> b
    default (==) :: (G.Generic a, GEq b (G.Rep a)) => a -> a -> b
    x == y = geq (G.from x) (G.from y)

    infix 4 /=
    -- | Inequality test, producing a @b@.
    (/=) :: a -> a -> b
    default (/=) :: (G.Generic a, GEq b (G.Rep a)) => a -> a -> b
    x /= y = gneq (G.from x) (G.from y)

-- | Membership test: applies @(== x)@ to every element of the container and
-- combines the results with 'any'.
elem :: (BoolType b, Eq b a, Foldable t) => a -> t a -> b
elem x = any (== x)

-- | Equality of 'Natural' numbers as a Haskell boolean; both methods are the
-- corresponding Prelude operators.
instance Eq Haskell.Bool Natural where
    (==) = (Haskell.==)
    (/=) = (Haskell./=)
-- | Equality of Haskell booleans as a Haskell boolean; both methods are the
-- corresponding Prelude operators.
instance Eq Haskell.Bool Haskell.Bool where
    (==) = (Haskell.==)
    (/=) = (Haskell./=)
-- | Equality of Haskell strings as a Haskell boolean; both methods are the
-- corresponding Prelude operators.
instance Eq Haskell.Bool Haskell.String where
    (==) = (Haskell.==)
    (/=) = (Haskell./=)
-- | Equality of 'Zp' values as a Haskell boolean; both methods are the
-- corresponding Prelude operators.
instance KnownNat n => Eq Haskell.Bool (Zp n) where
    (==) = (Haskell.==)
    (/=) = (Haskell./=)

-- | Equality of symbolic values @c f@, decided component by component.
--
-- Each method hands 'symbolic2F' two descriptions of the same per-component
-- flags: a function on base-field values and a circuit builder. The resulting
-- @c f@ is then unpacked and folded into a single 'Bool'.
instance (Symbolic c, Haskell.Eq (BaseField c), Representable f, Traversable f)
  => Eq (Bool c) (c f) where
    x == y =
        let
            result = symbolic2F x y
                -- On base-field values: 'one' where the components agree,
                -- 'zero' where they differ.
                (mzipWithRep (\i j -> bool zero one (i Haskell.== j)))
                (\x' y' -> do
                    -- One newly assigned variable per component, holding the
                    -- difference of the paired inputs.
                    difference <- for (mzipRep x' y') $ \(i, j) ->
                        newAssigned (\w -> w i - w j)
                    -- NOTE(review): the first component of 'runInvert' is
                    -- presumably a per-element "is zero" flag -- confirm
                    -- against its definition.
                    (isZeros, _) <- runInvert difference
                    return isZeros
                )
        in
            -- Every component flag has to hold.
            all Bool (unpacked result)

    x /= y =
        let
            result = symbolic2F x y
                -- On base-field values: 'one' where the components differ,
                -- 'zero' where they agree.
                (mzipWithRep (\i j -> bool zero one (i Haskell./= j)))
                (\x' y' -> do
                    -- One newly assigned variable per component, holding the
                    -- difference of the paired inputs.
                    difference <- for (mzipRep x' y') $ \(i, j) ->
                        newAssigned (\w -> w i - w j)
                    -- NOTE(review): the first component of 'runInvert' is
                    -- presumably a per-element "is zero" flag -- confirm
                    -- against its definition.
                    (isZeros, _) <- runInvert difference
                    -- Each flag is replaced by @one - flag@.
                    for isZeros $ \isZ ->
                      newAssigned (\w -> one - w isZ)
                )
        in
            -- A single component flag is enough.
            any Bool (unpacked result)

-- | Element-wise comparison of two vectors of the same type-level size.
--
-- '==' pairs the elements with '==' and folds the results with '&&' starting
-- from 'true'; '/=' pairs them with '/=' and folds with '||' starting from
-- 'false'.
instance (BoolType b, Eq b x) => Eq b (Vector n x) where
    u == v = V.foldl (&&) true (V.zipWith (==) (toV u) (toV v))
    u /= v = V.foldl (||) false (V.zipWith (/=) (toV u) (toV v))

-- | Equality of symbolic booleans. @deriving newtype@ reuses the 'Eq' instance
-- of the type that 'Bool' wraps instead of defining the methods here.
deriving newtype instance Symbolic c => Eq (Bool c) (Bool c)

-- Tuples of two to four components. No methods are defined, so the class
-- defaults apply: both values are converted with 'G.from' and compared with
-- 'geq' / 'gneq'.
instance (BoolType b, Eq b x0, Eq b x1) => Eq b (x0,x1)
instance (BoolType b, Eq b x0, Eq b x1, Eq b x2) => Eq b (x0,x1,x2)
instance (BoolType b, Eq b x0, Eq b x1, Eq b x2, Eq b x3) => Eq b (x0,x1,x2,x3)

-- Field extensions. No methods are defined, so the class defaults apply and
-- the comparison goes through 'geq' / 'gneq' on the generic representation.
-- NOTE(review): relies on 'G.Generic' instances for 'Ext2' and 'Ext3', which
-- are declared outside this file -- confirm.
instance (BoolType bool, Eq bool field) => Eq bool (Ext2 field i)
instance (BoolType bool, Eq bool field) => Eq bool (Ext3 field i)

-- | Equality on generic representations, producing a boolean of type @b@.
-- The default methods of 'Eq' call these on the result of 'G.from'.
class GEq b u where
    -- | Generic counterpart of '=='.
    geq :: u x -> u x -> b
    -- | Generic counterpart of '/='.
    gneq :: u x -> u x -> b

-- | Products: two products are equal when both halves are equal ('&&'), and
-- unequal when either half is unequal ('||').
instance (BoolType b, GEq b u, GEq b v) => GEq b (u G.:*: v) where
    geq (x0 G.:*: x1) (y0 G.:*: y1) = geq x0 y0 && geq x1 y1
    gneq (x0 G.:*: x1) (y0 G.:*: y1) = gneq x0 y0 || gneq x1 y1

-- | Metadata wrappers: unwrap 'G.M1' on both sides and compare the contents.
instance GEq b v => GEq b (G.M1 i c v) where
    geq (G.M1 x) (G.M1 y) = geq x y
    gneq (G.M1 x) (G.M1 y) = gneq x y

-- | Fields: unwrap 'G.K1' on both sides and compare the stored values with
-- the field type's own 'Eq' instance.
instance Eq b x => GEq b (G.Rec0 x) where
    geq (G.K1 x) (G.K1 y) = x == y
    gneq (G.K1 x) (G.K1 y) = x /= y