{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}

module ZkFold.Symbolic.Data.Input (
    SymbolicInput (..)
) where

import           Data.Type.Equality               (type (~))
import           Data.Typeable                    (Proxy (..))
import qualified GHC.Generics                     as G
import           GHC.TypeLits                     (KnownNat)
import           Prelude                          (($), (.))

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Data.Vector          (Vector)
import           ZkFold.Symbolic.Class
import           ZkFold.Symbolic.Data.Bool
import           ZkFold.Symbolic.Data.Class
import           ZkFold.Symbolic.Data.Combinators
import           ZkFold.Symbolic.MonadCircuit

-- | A class for Symbolic input: a 'SymbolicOutput' that additionally has a
-- validity predicate.
--
-- 'isValid' returns a symbolic 'Bool' in the same context as the value.
--
-- Types with a 'G.Generic' instance whose representation has a
-- 'GSymbolicInput' instance may omit 'isValid'. The default converts the
-- value to its generic representation and and-s together the checks of
-- all of its fields.
class SymbolicOutput d => SymbolicInput d where
    -- | Symbolic check that a value is well-formed.
    isValid :: d -> Bool (Context d)
    default isValid ::
      (G.Generic d, GSymbolicInput (G.Rep d), GContext (G.Rep d) ~ Context d)
      => d -> Bool (Context d)
    -- The type application pins 'gisValid' to the generic representation
    -- of @d@. 'G.from' produces a value of that representation.
    isValid = gisValid @(G.Rep d) . G.from

-- | A symbolic boolean is valid when its single underlying variable @v@
-- holds 0 or 1.
--
-- The check allocates a fresh variable @u = v * (1 - v)@ and returns
-- 'isZero' applied to it. Assuming @'BaseField' c@ is a field, @v * (1 - v)@
-- vanishes exactly when @v@ is 0 or 1.
-- NOTE(review): this relies on 'isZero' yielding 1 for a zero input and 0
-- otherwise; confirm against its definition.
instance Symbolic c => SymbolicInput (Bool c) where
  isValid (Bool b) = Bool $ fromCircuitF b $
      \(G.Par1 v) -> do
        -- u = v * (1 - v), built from the one input variable.
        u <- newAssigned (\x -> x v * (one - x v))
        isZero $ G.Par1 u

-- | A raw circuit value @c f@ imposes no constraint of its own, so 'isValid'
-- is constantly 'true'.
instance (Symbolic c, LayoutFunctor f) => SymbolicInput (c f) where
  isValid _ = true

-- | A 'Proxy' carries no data, so 'isValid' is constantly 'true'.
instance Symbolic c => SymbolicInput (Proxy c) where
  isValid _ = true

-- | A pair is valid when both components are valid.
--
-- No method is written here. 'isValid' comes from the class's generic
-- default, which combines the component checks with '&&'. Both components
-- must live in the same 'Context'.
instance
    ( SymbolicInput x
    , SymbolicInput y
    , Context x ~ Context y
    ) => SymbolicInput (x, y)

-- | A triple is valid when all three components are valid.
--
-- As with pairs, 'isValid' is the class's generic default. All three
-- components must live in the same 'Context'.
instance
    ( SymbolicInput x
    , SymbolicInput y
    , SymbolicInput z
    , Context x ~ Context y
    , Context y ~ Context z
    ) => SymbolicInput (x, y, z)

-- | A vector is valid when every element is valid: 'all' over the
-- element-wise 'isValid' checks.
instance (KnownNat n, SymbolicInput x) => SymbolicInput (Vector n x) where
  isValid = all isValid

-- | Generic counterpart of 'SymbolicInput', defined over "GHC.Generics"
-- representation types. The default 'isValid' is built on it.
class GSymbolicData u => GSymbolicInput u where
    -- | Validity check for a value of a generic representation.
    gisValid :: u a -> Bool (GContext u)

-- | A product is valid when both of its halves are valid. The halves share
-- one context, so their checks are combined with '&&'.
instance
    ( GContext u ~ GContext v
    , GSupport u ~ GSupport v
    , GSymbolicInput u
    , GSymbolicInput v
    ) => GSymbolicInput (u G.:*: v) where
    gisValid (l G.:*: r) = gisValid l && gisValid r

-- | Metadata wrappers ('G.M1') add no data of their own, so the check
-- delegates to the wrapped representation.
instance GSymbolicInput u => GSymbolicInput (G.M1 i c u) where
    gisValid (G.M1 inner) = gisValid inner

-- | A field ('G.Rec0') defers to the field type's own 'SymbolicInput'
-- instance.
instance SymbolicInput x => GSymbolicInput (G.Rec0 x) where
    gisValid (G.K1 field) = isValid field