{-# LANGUAGE DeriveAnyClass       #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}

module ZkFold.Symbolic.Algorithms.FFT
    ( fft
    , ifft
    ) where

import           Data.Maybe                       (fromJust)
import qualified Data.Vector                      as V
import           Prelude                          (pure, ($), (.))
import qualified Prelude                          as P

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Basic.Number
import           ZkFold.Base.Data.HFunctor        (hmap)
import           ZkFold.Base.Data.Vector          (Vector (..), toV)
import           ZkFold.Symbolic.Class
import           ZkFold.Symbolic.MonadCircuit     (MonadCircuit (..), newAssigned)

-- | Forward discrete Fourier transform of a length-@2^n@ symbolic vector.
--
-- The vector is unwrapped with 'toV' and transformed by 'fft''. The
-- transform root is the value returned by @'rootOfUnity' n@ and the scaling
-- factor is 'one'. The result is then rewrapped into a @Vector (2^n)@
-- via 'hmap'. The circuit function is lifted into @ctx@ with 'fromCircuitF'.
--
-- Fails (through 'fromJust') when @'rootOfUnity' n@ is 'Nothing' for the
-- base field of @ctx@.
fft :: forall ctx n . (Symbolic ctx, KnownNat n) => ctx (Vector (2^n)) -> ctx (Vector (2^n))
fft v = hmap Vector $ fromCircuitF v (fft' (value @n) u one . toV)
    where
        -- Transform root. NOTE(review): presumably a primitive 2^n-th root
        -- of unity; the butterfly sign flip in 'fft'' relies on that.
        u :: BaseField ctx
        u = fromJust $ rootOfUnity (value @n)

-- | Inverse discrete Fourier transform of a length-@2^n@ symbolic vector.
--
-- This runs the same butterfly network as 'fft' ('fft''), with two
-- differences:
--
-- * the root is the multiplicative inverse of @'rootOfUnity' n@;
--
-- * every output element is scaled by @1 / 2^n@.
--
-- NOTE(review): @ifft . fft@ is then the identity, assuming
-- @'rootOfUnity' n@ is a primitive @2^n@-th root of unity. Confirm that
-- contract of 'rootOfUnity'.
--
-- Fails (through 'fromJust') when @'rootOfUnity' n@ is 'Nothing' for the
-- base field of @ctx@.
ifft :: forall ctx n . (Symbolic ctx, KnownNat n) => ctx (Vector (2^n)) -> ctx (Vector (2^n))
ifft v = hmap Vector $ fromCircuitF v (fft' (value @n) u nInv . toV)
    where
        -- Inverse of the forward transform root used by 'fft'.
        u :: BaseField ctx
        u = one // fromJust (rootOfUnity (value @n))

        -- 1 / 2^n. 'fft'' applies this scale once to every output element.
        nInv :: BaseField ctx
        nInv = one // fromConstant ((2 :: Natural) ^ value @n)

-- | Recursive radix-2 (Cooley–Tukey) FFT on circuit variables.
--
-- @fft' n wn s v@ expects @v@ to hold @2^n@ variables. Both callers pass
-- 'toV' of a @Vector (2^n)@.
--
-- For @n > 0@ it works in three steps:
--
-- 1. Split @v@ into its even-indexed half and its odd-indexed half.
--
-- 2. Transform each half recursively, with root @wn * wn@ and scale 'one'.
--
-- 3. Combine the halves with one butterfly per output index @k < 2^(n-1)@:
--
-- > y[k]           = s * a0Hat[k] + (s * wn^k) * a1Hat[k]
-- > y[k + 2^(n-1)] = s * a0Hat[k] - (s * wn^k) * a1Hat[k]
--
-- Each output element is a fresh variable allocated with 'newAssigned'.
-- The result is @s@ times the DFT of @v@ with root @wn@, provided @wn@ is a
-- primitive @2^n@-th root of unity. The minus sign relies on
-- @wn^(2^(n-1)) == -1@.
--
-- NOTE: for @n == 0@ the input is returned unchanged and @s@ is /not/
-- applied. This case is only reached with @s == one@:
--
-- * recursive calls pass 'one';
--
-- * at the top level, 'fft' passes 'one' and 'ifft' passes @1 / 2^0@.
fft'
    :: forall a i w m
    .  Arithmetic a
    => MonadCircuit i a w m
    => Natural        -- ^ @n@: log2 of the input length
    -> a              -- ^ @wn@: transform root
    -> a              -- ^ @s@: scale applied to every output (for @n > 0@)
    -> V.Vector i     -- ^ input variables, length @2^n@
    -> m (V.Vector i)
fft' 0 _ _ v  = pure v
fft' n wn s v = do
    a0Hat <- fft' (n -! 1) wn2 one a0
    a1Hat <- fft' (n -! 1) wn2 one a1
    V.generateM fullLen $ \ix -> do
        let arrIx = ix `P.mod` halfLen

            -- First half of the output adds the odd-index term;
            -- second half subtracts it.
            op :: AdditiveGroup p => p -> p -> p
            op = if ix P.< halfLen then (+) else (-)

            a0k = a0Hat `V.unsafeIndex` arrIx
            a1k = a1Hat `V.unsafeIndex` arrIx
            wnp = twiddles `V.unsafeIndex` arrIx
        newAssigned $ \p -> scale s (p a0k) `op` scale (s * wnp) (p a1k)
    where
        a0 = V.ifilter (\i _ -> i `P.mod` 2 P.== 0) v
        a1 = V.ifilter (\i _ -> i `P.mod` 2 P.== 1) v

        wn2 = wn * wn

        fullLen, halfLen :: P.Int
        fullLen = P.fromIntegral $ (2 :: Natural) ^ n
        halfLen = P.fromIntegral $ (2 :: Natural) ^ (n -! 1)

        -- Twiddle factors: twiddles ! k == wn ^ k for 0 <= k < halfLen.
        -- Built once by repeated multiplication, rather than exponentiating
        -- wn twice per output index.
        twiddles :: V.Vector a
        twiddles = V.iterateN halfLen (* wn) one