{-# LANGUAGE TypeOperators #-}

module ZkFold.Symbolic.Class where

import           Control.DeepSeq                  (NFData)
import           Control.Monad
import           Data.Eq                          (Eq)
import           Data.Foldable                    (Foldable)
import           Data.Function                    ((.))
import           Data.Functor                     ((<$>))
import           Data.Kind                        (Type)
import           Data.Ord                         (Ord)
import           Data.Type.Equality               (type (~))
import           GHC.Generics                     (type (:.:) (unComp1))
import           Numeric.Natural                  (Natural)

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Control.HApplicative (HApplicative (hpair, hunit))
import           ZkFold.Base.Data.Package         (Package (pack))
import           ZkFold.Base.Data.Product         (uncurryP)
import           ZkFold.Symbolic.MonadCircuit

-- | An /arithmetic/ field: a field of residues (@'ResidueField' 'Natural'@)
-- with decidable equality ('Eq') and ordering ('Ord'), whose values can be
-- fully evaluated ('NFData').
type Arithmetic a =
  ( ResidueField Natural a
  , Eq a
  , Ord a
  , NFData a
  )

-- | A type of mappings between functors inside a circuit.
-- @fs@ are input functors, @g@ is an output functor, @c@ is context.
--
-- A function is a mapping between functors inside a circuit if,
-- given an arbitrary builder of circuits @m@ over @c@ with arbitrary @i@ as
-- variables (@i@ must be an instance of 'NFData'), it maps inputs of shapes
-- @fs@ (one curried argument @f i@ per functor @f@ in @fs@, see 'FunBody')
-- to @g@ many outputs using @m@.
--
-- NOTE: the property above is correct by construction for each function of a
-- suitable type, you don't have to check it yourself.
type CircuitFun (fs :: [Type -> Type]) (g :: Type -> Type) (c :: (Type -> Type) -> Type) =
  forall i m. (NFData i, MonadCircuit i (BaseField c) (WitnessField c) m) => FunBody fs g i m

-- | Curried body of a circuit function: one argument @f i@ for every functor
-- @f@ in the list, ending in the monadic result @m (g i)@.
--
-- For example, @FunBody '[f, g] h i m@ reduces to @f i -> g i -> m (h i)@.
type family FunBody (fs :: [Type -> Type]) (g :: Type -> Type) (i :: Type) (m :: Type -> Type) where
  FunBody '[]          out var mon = mon (out var)
  FunBody (f ': rest)  out var mon = f var -> FunBody rest out var mon

-- | A Symbolic DSL for performant pure computations with arithmetic circuits.
-- @c@ is a generic context in which computations are performed.
class ( HApplicative c, Package c, Arithmetic (BaseField c)
      , ResidueField (Const (WitnessField c)) (WitnessField c)
      ) => Symbolic c where
    -- | Base algebraic field over which computations are performed.
    type BaseField c :: Type
    -- | Type of witnesses usable inside circuit construction.
    type WitnessField c :: Type

    -- | Computes witnesses (exact value may depend on the input to context).
    witnessF :: Functor f => c f -> f (WitnessField c)

    -- | To perform computations in a generic context @c@ -- that is, to form a
    -- mapping between @c f@ and @c g@ for given @f@ and @g@ -- you need to
    -- provide an algorithm for turning @f@ into @g@ inside a circuit.
    fromCircuitF :: c f -> CircuitFun '[f] g c -> c g

    -- | If there is a simpler implementation of a function in pure context,
    -- you can provide it via 'sanityF' to use it in pure contexts.
    --
    -- Arguments: the input, the pure implementation, and the generic
    -- (circuit-based) implementation. The default implementation ignores the
    -- pure implementation and applies the generic one to the input.
    sanityF :: BaseField c ~ a => c f -> (f a -> g a) -> (c f -> c g) -> c g
    sanityF x _ f = f x

-- | Embeds the pure value(s) into generic context @c@.
--
-- This runs a circuit with no inputs ('hunit'). The circuit's body ignores
-- its (empty) argument and returns the given base-field values, each
-- converted into a circuit variable with 'fromConstant'.
embed :: (Symbolic c, Functor f) => f (BaseField c) -> c f
embed cs = fromCircuitF hunit (\_ -> return (fromConstant <$> cs))

-- | Runs the unary function from @f@ into @g@ in a generic context @c@.
--
-- Both implementations are handed to 'sanityF': the pure function @f@, and
-- the generic one obtained by running the @'CircuitFun'@ with 'fromCircuitF'.
symbolicF ::
  (Symbolic c, BaseField c ~ a) => c f ->
  (f a -> g a) -> CircuitFun '[f] g c -> c g
symbolicF x f m = sanityF x f (\y -> fromCircuitF y m)

-- | Runs the binary function from @f@ and @g@ into @h@ in a generic context @c@.
--
-- The two inputs are combined with 'hpair'. The pure function and the circuit
-- body are uncurried with 'uncurryP', and the result is run with 'symbolicF'.
symbolic2F ::
  (Symbolic c, BaseField c ~ a) => c f -> c g ->
  (f a -> g a -> h a) -> CircuitFun '[f, g] h c -> c h
symbolic2F x y f m = symbolicF (hpair x y) (uncurryP f) (uncurryP m)

-- | Runs the binary @'CircuitFun'@ in a generic context.
--
-- The two inputs are combined with 'hpair', and the circuit body is uncurried
-- with 'uncurryP' so that the unary 'fromCircuitF' can run it.
fromCircuit2F :: Symbolic c => c f -> c g -> CircuitFun '[f, g] h c -> c h
fromCircuit2F x y m = fromCircuitF (hpair x y) (uncurryP m)

-- | Runs the ternary function from @f@, @g@ and @h@ into @k@ in a context @c@.
--
-- The first two inputs are combined with 'hpair'. The pure function and the
-- circuit body are uncurried on their first two arguments with 'uncurryP',
-- and the result is run with 'symbolic2F'.
symbolic3F ::
  (Symbolic c, BaseField c ~ a) => c f -> c g -> c h ->
  (f a -> g a -> h a -> k a) -> CircuitFun '[f, g, h] k c -> c k
symbolic3F x y z f m = symbolic2F (hpair x y) z (uncurryP f) (uncurryP m)

-- | Runs the ternary @'CircuitFun'@ in a generic context.
--
-- The first two inputs are combined with 'hpair', and the circuit body is
-- uncurried on its first two arguments with 'uncurryP'. The result is run
-- with 'fromCircuit2F'.
fromCircuit3F ::
  Symbolic c => c f -> c g -> c h -> CircuitFun '[f, g, h] k c -> c k
fromCircuit3F x y z m = fromCircuit2F (hpair x y) z (uncurryP m)

-- | Given a generic context @c@, runs the function from @f@ many @c g@'s into @c h@.
--
-- The container of contexts is packed into a single context over the
-- composed functor @f :.: g@ with 'pack'. The pure function and the circuit
-- body are adapted to that functor by precomposing 'unComp1', and the result
-- is run with 'symbolicF'.
symbolicVF ::
  (Symbolic c, BaseField c ~ a, WitnessField c ~ w, Foldable f, Functor f) =>
  f (c g) -> (f (g a) -> h a) ->
  (forall i m. MonadCircuit i a w m => f (g i) -> m (h i)) -> c h
symbolicVF xs f m = symbolicF (pack xs) (f . unComp1) (m . unComp1)

-- | Given a generic context @c@, runs the @'CircuitFun'@ from @f@ many @c g@'s into @c h@.
--
-- The container of contexts is packed into a single context over the
-- composed functor @f :.: g@ with 'pack'. The circuit body is adapted to that
-- functor by precomposing 'unComp1', and the result is run with 'fromCircuitF'.
fromCircuitVF ::
  (Symbolic c, BaseField c ~ a, WitnessField c ~ w, Foldable f, Functor f) =>
  f (c g) -> (forall i m. MonadCircuit i a w m => f (g i) -> m (h i)) -> c h
fromCircuitVF xs m = fromCircuitF (pack xs) (m . unComp1)