{-# LANGUAGE DerivingStrategies   #-}
{-# LANGUAGE UndecidableInstances #-}

module ZkFold.Symbolic.Data.Bool (
    BoolType(..),
    Bool(..),
    all,
    all1,
    any,
    and,
    or
) where

import           Control.DeepSeq                 (NFData)
import           Data.Eq                         (Eq (..))
import           Data.Foldable                   (Foldable (..))
import           Data.Function                   (($), (.))
import           Data.Functor                    (Functor, fmap, (<$>))
import           GHC.Generics                    (Generic, Par1 (..))
import qualified Prelude                         as Haskell
import           Text.Show                       (Show)

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Symbolic.Class
import           ZkFold.Symbolic.Interpreter     (Interpreter (..))
import           ZkFold.Symbolic.MonadCircuit    (newAssigned)

-- | Boolean-algebra operations over a carrier type @b@: the two constants,
-- negation, conjunction, disjunction and exclusive or.
class BoolType b where
    -- | The two Boolean constants.
    true, false :: b

    -- | Logical negation.
    not :: b -> b

    -- | Logical conjunction (right-associative, precedence 3).
    (&&) :: b -> b -> b
    infixr 3 &&

    -- | Logical disjunction (right-associative, precedence 2).
    (||) :: b -> b -> b
    infixr 2 ||

    -- | Exclusive or. No default is provided, so every instance must
    -- define it.
    xor :: b -> b -> b

-- | Plain Haskell booleans, delegating to the Prelude operations.
instance BoolType Haskell.Bool where
    true  = Haskell.True
    false = Haskell.False
    not   = Haskell.not
    (&&)  = (Haskell.&&)
    (||)  = (Haskell.||)
    -- Exclusive or on Haskell booleans is exactly inequality.
    -- Writing @xor = xor@ here would resolve to this very class method
    -- and diverge, so delegate to '(/=)' explicitly.
    xor   = (/=)

-- TODO (Issue #18): hide this constructor
-- | A symbolic Boolean: a wrapper around a one-slot payload @c Par1@
-- living in the symbolic context @c@.
newtype Bool c = Bool (c Par1)
    deriving stock Generic

-- Each instance is available whenever the wrapped payload @c Par1@
-- supports the same class.
deriving stock instance Show (c Par1) => Show (Bool c)
deriving stock instance Eq (c Par1) => Eq (Bool c)
deriving instance NFData (c Par1) => NFData (Bool c)

-- | Render an interpreted Boolean as @"True"@ when its underlying value
-- equals 'one', and as @"False"@ for any other value.
instance {-# OVERLAPPING #-} (Eq a, MultiplicativeMonoid a) => Show (Bool (Interpreter a)) where
    show b = case fromBool b == one of
        Haskell.True  -> "True"
        Haskell.False -> "False"

-- | Symbolic Booleans. Each value is a single slot of the base field.
-- Every operation allocates one new variable via 'newAssigned', whose
-- value is the polynomial noted on that method. On inputs 0 and 1 these
-- polynomials coincide with the corresponding logical operations.
-- NOTE(review): nothing in this instance constrains inputs to {0, 1};
-- presumably that invariant is maintained by whoever builds a 'Bool'.
instance Symbolic c => BoolType (Bool c) where
    true  = Bool $ embed (Par1 one)
    false = Bool $ embed (Par1 zero)

    -- 1 - v
    not (Bool b) = Bool $ fromCircuitF b $ \(Par1 v) ->
        Par1 <$> newAssigned (\x -> one - x v)

    -- v1 * v2
    Bool b1 && Bool b2 = Bool $ fromCircuit2F b1 b2 $ \(Par1 v1) (Par1 v2) ->
        Par1 <$> newAssigned (\x -> x v1 * x v2)

    -- v1 + v2 - v1 * v2
    Bool b1 || Bool b2 = Bool $ fromCircuit2F b1 b2 $ \(Par1 v1) (Par1 v2) ->
        Par1 <$> newAssigned (\x ->
            let p = x v1
                q = x v2
            in p + q - p * q)

    -- v1 + v2 - 2 * v1 * v2
    Bool b1 `xor` Bool b2 = Bool $ fromCircuit2F b1 b2 $ \(Par1 v1) (Par1 v2) ->
        Par1 <$> newAssigned (\x ->
            let p = x v1
                q = x v2
            in p + q - (one + one) * p * q)

-- | Extract the raw value held by an interpreted 'Bool'.
fromBool :: Bool (Interpreter a) -> a
fromBool (Bool (Interpreter slot)) = unPar1 slot

-- | Conjunction of @f x@ over every element, folded from the right and
-- seeded with 'true' (so an empty container yields 'true').
all :: (BoolType b, Foldable t) => (x -> b) -> t x -> b
all f = foldr (\x acc -> f x && acc) true

-- | Conjunction of every element, folded from the right and seeded with
-- 'true'. This is the same as @all id@.
and :: (BoolType b, Foldable t) => t b -> b
and = foldr (&&) true

-- | Disjunction of every element, folded from the right and seeded with
-- 'false'. This is the same as @any id@.
or :: (BoolType b, Foldable t) => t b -> b
or = foldr (||) false

-- | Conjunction of @f x@ over a container without seeding the fold with
-- 'true'. For a non-empty container the result is
-- @f x1 && (f x2 && (... && f xn))@, so no extra constant is combined in.
--
-- An empty container yields 'true'. Previously this case crashed inside
-- 'foldr1', which is partial on empty structures.
all1 :: (BoolType b, Functor t, Foldable t) => (x -> b) -> t x -> b
all1 f xs =
    if null xs
        then true
        else (foldr1 (&&) . fmap f) xs

-- | Disjunction of @f x@ over every element, folded from the right and
-- seeded with 'false' (so an empty container yields 'false').
any :: (BoolType b, Foldable t) => (x -> b) -> t x -> b
any f = foldr (\x acc -> f x || acc) false