{-# LANGUAGE RebindableSyntax     #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}

{-# OPTIONS_GHC -Wno-orphans #-}

module ZkFold.Symbolic.Data.BLS12_381 (BLS12_381_G1_Point) where

import           Prelude                                     (fromInteger, type (~), ($))
import qualified Prelude                                     as P

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Basic.Number
import           ZkFold.Base.Algebra.EllipticCurve.Class
import           ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_Base)
import           ZkFold.Symbolic.Class                       (Symbolic (..))
import           ZkFold.Symbolic.Data.Bool
import           ZkFold.Symbolic.Data.ByteString
import           ZkFold.Symbolic.Data.Combinators            (RegisterSize (Auto))
import           ZkFold.Symbolic.Data.Conditional
import           ZkFold.Symbolic.Data.FFA
import           ZkFold.Symbolic.Data.FieldElement

-- | Symbolic point of the curve tagged @"BLS12-381-G1"@ in 'Weierstrass' form.
-- Both coordinates are 'FFA' values over 'BLS12_381_Base' with the 'Auto'
-- register size, living in the symbolic context @ctx@.
type BLS12_381_G1_Point ctx = Weierstrass "BLS12-381-G1" (Point (FFA BLS12_381_Base 'Auto ctx))

-- | 'CyclicGroup' structure on the symbolic BLS12-381 G1 point type.
--
-- Scalars are 'FieldElement's of the same context @ctx@, and the generator is
-- built with 'pointXY' from its two affine coordinates.
instance (Symbolic ctx, KnownFFA BLS12_381_Base 'Auto ctx) => CyclicGroup (BLS12_381_G1_Point ctx) where
  type ScalarFieldOf (BLS12_381_G1_Point ctx) = FieldElement ctx

  -- Affine (x, y) coordinates of the generator, written as 'Natural' literals
  -- and lifted into the foreign-field type via 'fromConstant'.
  -- These are the standard BLS12-381 G1 generator coordinates (the leading
  -- zero nibble of y is simply omitted in the literal).
  pointGen = pointXY
    (fromConstant (0x17f1d3a73197d7942695638c4fa9ac0fc3688c4f9774b905a14e3a3f171bac586c55e83ff97a1aeffb3af00adb22c6bb :: Natural))
    (fromConstant (0x8b3f481e3aaa0f1a09e30ed741d8ae4fcf5e095d5d00af600db18cb2c04b3edd03cc744a2888ae40caa232946c5e7e1 :: Natural))

-- | Scalar multiplication of a BLS12-381 G1 point by a symbolic 'FieldElement'.
--
-- The scalar is expanded into @bits@ bits (@bits ~ NumberOfBits (BaseField ctx)@).
-- The point is repeatedly doubled (@x, x + x, ...@), each doubling is paired
-- with one bit index, and the doublings whose bit is set are summed; the
-- others contribute 'zero'.
instance
  ( Symbolic ctx
  , a ~ BaseField ctx
  , bits ~ NumberOfBits a
  , KnownFFA BLS12_381_Base 'Auto ctx
  ) => Scale (FieldElement ctx) (BLS12_381_G1_Point ctx) where

    scale sc x =
        sum $ P.zipWith
            -- Keep the doubled point @p@ when bit @b@ is set, otherwise 'zero'.
            (\b p -> bool @(Bool ctx) zero p (isSet bits b))
            -- Bit indices from @upper@ down to 0. The first doubling (@x@
            -- itself) is paired with index @upper@.
            -- NOTE(review): this presumes index @upper@ holds the least
            -- significant bit of the scalar — confirm against the bit order
            -- of 'binaryExpansion' / 'isSet'.
            [upper, upper -! 1 .. 0]
            -- Infinite list x, 2x, 4x, ...; 'P.zipWith' truncates it to the
            -- length of the index list.
            (P.iterate (\e -> e + e) x)
        where
            -- Binary expansion of the scalar, wrapped as a 'ByteString' so
            -- that individual bits can be queried with 'isSet'.
            bits :: ByteString bits ctx
            bits = ByteString $ binaryExpansion sc

            -- Highest bit index, i.e. the number of bits minus one
            -- (using the truncated subtraction '-!').
            upper :: Natural
            upper = value @bits -! 1