{-# LANGUAGE RebindableSyntax     #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}

{-# OPTIONS_GHC -Wno-orphans #-}

module ZkFold.Symbolic.Data.Ed25519 (Ed25519_Point) where

import           Prelude                                   (fromInteger, type (~), ($))
import qualified Prelude                                   as P

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Basic.Number
import           ZkFold.Base.Algebra.EllipticCurve.Class
import           ZkFold.Base.Algebra.EllipticCurve.Ed25519 (Ed25519_Base, Ed25519_PointOf)
import           ZkFold.Symbolic.Class                     (Symbolic (..))
import           ZkFold.Symbolic.Data.Bool
import           ZkFold.Symbolic.Data.ByteString
import           ZkFold.Symbolic.Data.Combinators          (RegisterSize (Auto))
import           ZkFold.Symbolic.Data.Conditional
import           ZkFold.Symbolic.Data.FFA
import           ZkFold.Symbolic.Data.FieldElement

-- | Symbolic Ed25519 point: an 'Ed25519_PointOf' whose coordinates are
-- @FFA Ed25519_Base 'Auto ctx@ values, i.e. elements of 'Ed25519_Base'
-- represented in the symbolic context @ctx@ with an automatically chosen
-- register size ('Auto').
type Ed25519_Point ctx = Ed25519_PointOf (FFA Ed25519_Base 'Auto ctx)

-- | Cyclic-group structure of symbolic Ed25519 points.
--
-- The generator is built with 'pointXY' from two constant base-field
-- coordinates; these are the affine coordinates of the Ed25519 base point
-- @B@ (RFC 8032, section 5.1).
--
-- NOTE(review): the scalar type is the circuit-native 'FieldElement ctx',
-- not integers modulo the Ed25519 group order — confirm this is intended.
instance (Symbolic ctx, KnownFFA Ed25519_Base 'Auto ctx) => CyclicGroup (Ed25519_Point ctx) where
  type ScalarFieldOf (Ed25519_Point ctx) = FieldElement ctx

  pointGen =
    pointXY
      -- x-coordinate of the base point
      (fromConstant (15112221349535400772501151409588531511454012693041857206046113283949847762202 :: Natural))
      -- y-coordinate of the base point
      (fromConstant (46316835694926478169428394003475163141307993866256225615783033603165251855960 :: Natural))

-- | Scalar multiplication of a symbolic Ed25519 point by a circuit field
-- element.
--
-- The scalar is expanded into @bits@ bits. The points @x, 2x, 4x, ...@ are
-- generated by repeated doubling. Each point is paired with one bit index:
-- it is kept when that bit of the scalar is set and replaced by 'zero'
-- otherwise. The kept points are then summed.
instance
  ( Symbolic ctx
  , a ~ BaseField ctx
  , bits ~ NumberOfBits a
  , KnownFFA Ed25519_Base 'Auto ctx
  ) => Scale (FieldElement ctx) (Ed25519_Point ctx) where

    scale sc x = sum $ P.zipWith selectIf bitIndices (P.iterate (\e -> e + e) x)
        where
            -- Bit decomposition of the scalar, wrapped as a ByteString so that
            -- individual bits can be tested with 'isSet'.
            bits :: ByteString bits ctx
            bits = ByteString $ binaryExpansion sc

            -- Keep @p@ in the sum only when bit @b@ of the scalar is set.
            selectIf :: Natural -> Ed25519_Point ctx -> Ed25519_Point ctx
            selectIf b p = bool @(Bool ctx) zero p (isSet bits b)

            -- Bit indices @n-1, n-2, .., 0@, where @n = value \@bits@.
            -- Index @n-1@ is paired with @x@, index @n-2@ with @2x@, and so on.
            -- NOTE(review): this pairing assumes 'isSet' index 0 addresses the
            -- most significant bit of the expansion — confirm against the bit
            -- order of 'binaryExpansion'.
            -- Built with 'P.take' instead of the sequence
            -- @[n-1, n-2 .. 0]@: for @n <= 1@ that sequence degenerates to
            -- @[0, 0 .. 0]@, which is infinite and would make 'sum' diverge.
            bitIndices :: [Natural]
            bitIndices = P.reverse $ P.take (P.fromIntegral (value @bits)) [0 ..]