{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}

module ZkFold.Base.Protocol.Plonkup.Internal where

import           Data.Constraint                                     (withDict)
import           Data.Constraint.Nat                                 (plusNat, timesNat)
import           Data.Functor.Classes                                (Show1)
import           Data.Functor.Rep                                    (Rep)
import           Prelude                                             hiding (Num (..), drop, length, sum, take, (!!),
                                                                      (/), (^))
import           Test.QuickCheck                                     (Arbitrary (..))

import           ZkFold.Base.Algebra.Basic.Number
import           ZkFold.Base.Algebra.EllipticCurve.Class             (CyclicGroup (..))
import           ZkFold.Base.Algebra.Polynomials.Univariate          (PolyVec)
import           ZkFold.Base.Protocol.Plonkup.Utils
import           ZkFold.Symbolic.Compiler                            ()
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal

{-
    NOTE: the transcript type is a parameter because on-chain code uses BuiltinByteString
    and off-chain code uses ByteString.
    Keeping it abstract also means this library does not have to depend on Cardano libraries.
-}

-- | A Plonkup protocol instance: an arithmetic circuit together with the
-- scalar-field constants the protocol works with.
--
-- Only @g1@ (through 'ScalarFieldOf') and the circuit parameters @p@, @i@ and
-- @l@ occur in the fields. @n@ is a type-level 'Natural'; the 'Arbitrary'
-- instance in this module computes 'omega', 'k1' and 'k2' from it with
-- 'getParams'. @g2@ and @transcript@ do not occur in any field.
data Plonkup p i (n :: Natural) l g1 g2 transcript = Plonkup
    { omega :: ScalarFieldOf g1
      -- ^ First constant returned by 'getParams'. Presumably the generator of
      -- the evaluation domain (a root of unity) — TODO confirm.
    , k1    :: ScalarFieldOf g1
      -- ^ Second constant returned by 'getParams'. Presumably a coset shift
      -- used by the permutation argument — TODO confirm.
    , k2    :: ScalarFieldOf g1
      -- ^ Third constant returned by 'getParams'. Presumably a second coset
      -- shift — TODO confirm.
    , ac    :: ArithmeticCircuit (ScalarFieldOf g1) p i l
      -- ^ The arithmetic circuit this instance is built for.
    , x     :: ScalarFieldOf g1
      -- ^ A free scalar-field element (drawn independently by the 'Arbitrary'
      -- instance). Presumably the setup secret — TODO confirm.
    }

-- | Size of the Plonkup permutation for a size parameter @n@: @3 * n@.
-- NOTE(review): presumably one slot per wire per row with three wire
-- columns — TODO confirm against the permutation construction.
type PlonkupPermutationSize n = 3 * n

-- | Length used for the "extended" polynomial vectors ('PlonkupPolyExtended').
--
-- The maximum degree of the polynomials we need in the protocol is `4 * n + 5`.
-- A polynomial of that degree has @4 * n + 6@ coefficients, hence the @+ 6@.
type PlonkupPolyExtendedLength n = 4 * n + 6

-- | Run a computation that needs @KnownNat (4 * n + 6)@ — the length used by
-- 'PlonkupPolyExtendedLength' — given only @KnownNat n@.
--
-- The dictionary is built in two steps: first @KnownNat (4 * n)@ from
-- @KnownNat 4@ and @KnownNat n@, then @KnownNat (4 * n + 6)@ from that and
-- @KnownNat 6@.
with4n6 :: forall n {r}. KnownNat n => (KnownNat (4 * n + 6) => r) -> r
with4n6 f = withDict scaleBy4 (withDict shiftBy6 f)
  where
    -- Entailment (KnownNat 4, KnownNat n) :- KnownNat (4 * n).
    scaleBy4 = timesNat @4 @n
    -- Entailment (KnownNat (4 * n), KnownNat 6) :- KnownNat (4 * n + 6);
    -- its premise is discharged by the outer 'withDict'.
    shiftBy6 = plusNat @(4 * n) @6

-- | Polynomial vector over the scalar field of @g@, sized by
-- 'PlonkupPolyExtendedLength' @n@ (i.e. @4 * n + 6@).
type PlonkupPolyExtended n g = PolyVec (ScalarFieldOf g) (PlonkupPolyExtendedLength n)

-- | Debug rendering: the prefix @"Plonkup: "@ followed by 'omega', 'k1', 'k2',
-- the circuit's outputs ('acOutput'), the circuit itself and 'x', each shown
-- with 'show' and separated by single spaces.
instance (Show (ScalarFieldOf g1), Show (Rep i), Show1 l, Ord (Rep i)) => Show (Plonkup p i n l g1 g2 t) where
    show (Plonkup omega' k1' k2' circuit x') =
        "Plonkup: " ++ unwords
            [ show omega'
            , show k1'
            , show k2'
            , show (acOutput circuit)
            , show circuit
            , show x'
            ]

-- | Random 'Plonkup' values for property tests.
--
-- The circuit and 'x' are drawn from their own 'Arbitrary' instances, circuit
-- first. 'omega', 'k1' and 'k2' are not random: they are computed by
-- 'getParams' from the type-level size @n@.
instance
  ( KnownNat n, Arithmetic (ScalarFieldOf g1), Arbitrary (ScalarFieldOf g1)
  , Arbitrary (ArithmeticCircuit (ScalarFieldOf g1) p i l)
  ) => Arbitrary (Plonkup p i n l g1 g2 t) where
    arbitrary =
        -- Draw the circuit, then map the constructor over the generator for 'x'.
        arbitrary >>= \circuit -> Plonkup omega' k1' k2' circuit <$> arbitrary
      where
        (omega', k1', k2') = getParams (value @n)