{-# LANGUAGE DeriveAnyClass       #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Radix-2 FFT and inverse FFT over symbolic vectors whose length is a power
-- of two. Every output element is emitted as a fresh circuit variable.
module ZkFold.Symbolic.Algorithms.FFT
    ( fft
    , ifft
    ) where

import           Data.Maybe                       (fromMaybe)
import qualified Data.Vector                      as V
import           Prelude                          (pure, ($), (.))
import qualified Prelude                          as P

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Basic.Number
import           ZkFold.Base.Data.HFunctor        (hmap)
import           ZkFold.Base.Data.Vector          (Vector (..), toV)
import           ZkFold.Symbolic.Class
import           ZkFold.Symbolic.MonadCircuit     (MonadCircuit (..), newAssigned)

-- | Forward FFT of a length-@2^n@ symbolic vector.
--
-- Runs 'fft'' with the root returned by @'rootOfUnity' n@ and no extra
-- scaling ('one').
--
-- Calls 'P.error' if the base field provides no such root.
fft
    :: forall ctx n . (Symbolic ctx, KnownNat n)
    => ctx (Vector (2^n)) -> ctx (Vector (2^n))
fft v = hmap Vector $ fromCircuitF v (fft' (value @n) w one . toV)
    where
        w :: BaseField ctx
        w = requireRootOfUnity (value @n)

-- | Inverse of 'fft'.
--
-- Runs 'fft'' with the multiplicative inverse of the root used by 'fft'. It
-- also scales every output by @1 / 2^n@.
--
-- Calls 'P.error' if the base field provides no such root.
ifft
    :: forall ctx n . (Symbolic ctx, KnownNat n)
    => ctx (Vector (2^n)) -> ctx (Vector (2^n))
ifft v = hmap Vector $ fromCircuitF v (fft' (value @n) wInv nInv . toV)
    where
        wInv :: BaseField ctx
        wInv = one // requireRootOfUnity (value @n)

        -- 1 / 2^n, the normalisation factor of the inverse transform
        nInv :: BaseField ctx
        nInv = one // fromConstant ((2 :: Natural) ^ value @n)

-- | @'rootOfUnity' k@, failing with a descriptive message instead of the
-- anonymous 'Data.Maybe.fromJust' crash when the field has no such root.
--
-- NOTE(review): the radix-2 recursion in 'fft'' needs this to be a primitive
-- @2^k@-th root of unity. Confirm that against the 'rootOfUnity' contract.
requireRootOfUnity :: Field a => Natural -> a
requireRootOfUnity k = fromMaybe (P.error msg) (rootOfUnity k)
    where
        msg = "ZkFold.Symbolic.Algorithms.FFT: base field has no root of unity for k = "
            P.++ P.show k

-- | One radix-2 decimation-in-time Cooley–Tukey step, emitted as circuit
-- variables.
--
-- @fft' n wn s v@ first splits @v@ into its even-indexed half @e@ and its
-- odd-indexed half @o@. It transforms each half recursively with @wn * wn@,
-- then combines them with the butterfly
--
-- > out[k]           = s * e'[k] + (s * wn^k) * o'[k]
-- > out[k + 2^(n-1)] = s * e'[k] - (s * wn^k) * o'[k]
--
-- Here @e'@ and @o'@ are the transformed halves and @0 <= k < 2^(n-1)@.
-- Each output is a fresh variable created with 'newAssigned'.
--
-- The recursive calls pass 'one' as the scale, so @s@ only affects the
-- outermost level. For @n = 0@ the input is returned unchanged and @s@ is
-- ignored.
--
-- Precondition: @v@ has exactly @2^n@ elements, because 'V.unsafeIndex' is
-- used without bounds checks. 'fft' and 'ifft' guarantee this through the
-- @Vector (2^n)@ type.
fft'
    :: forall a i w m . Arithmetic a => MonadCircuit i a w m
    => Natural -> a -> a -> V.Vector i -> m (V.Vector i)
fft' 0 _ _ v = pure v
fft' n wn s v = do
    a0Hat <- fft' (n -! 1) wn2 one a0
    a1Hat <- fft' (n -! 1) wn2 one a1
    V.generateM (P.fromIntegral $ (2 :: Natural) ^ n) $ \ix -> do
        let arrIx :: P.Int
            arrIx = ix `P.mod` halfLen

            -- first half of the output takes the sum, second half the difference
            op :: AdditiveGroup p => p -> p -> p
            op = if ix P.< halfLen then (+) else (-)

            a0k = a0Hat `V.unsafeIndex` arrIx
            a1k = a1Hat `V.unsafeIndex` arrIx
            wnp = wn ^ (P.fromIntegral arrIx :: Natural)
        newAssigned $ \p -> scale s (p a0k) `op` scale (s * wnp) (p a1k)
    where
        -- even- and odd-indexed elements of the input
        a0, a1 :: V.Vector i
        a0 = V.ifilter (\i _ -> i `P.mod` 2 P.== 0) v
        a1 = V.ifilter (\i _ -> i `P.mod` 2 P.== 1) v

        -- sub-transforms of half the length use the squared root
        wn2 :: a
        wn2 = wn * wn

        halfLen :: P.Int
        halfLen = P.fromIntegral $ (2 :: Natural) ^ (n -! 1)