{-# LANGUAGE TypeOperators, RankNTypes #-}

-- | Calculation of roots of unity for the forward and inverse DFT\/FFT.
module Data.Array.Repa.Algorithms.DFT.Roots
        ( calcRootsOfUnityP
        , calcInverseRootsOfUnityP)
where
import Data.Array.Repa
import Data.Array.Repa.Algorithms.Complex


-- | Calculate the roots of unity used by the forward transform.
--
--   For a result of extent @sh :. n@, the element at index @_ :. i@ is the
--   pair @(cos theta, - sin theta)@ where @theta = 2 * pi * i / n@.
--   Read as @(real, imaginary)@, this is @exp (-i * theta)@.
--   The value depends only on the innermost index, so every row of the
--   result holds the same @n@ roots.
calcRootsOfUnityP
        :: (Shape sh, Monad m)
        => (sh :. Int)                  -- ^ Extent of the result; the innermost
                                        --   dimension @n@ is the transform length.
        -> m (Array U (sh :. Int) Complex)

calcRootsOfUnityP sh@(_ :. n)
 = computeP $ fromFunction sh f
 where
        -- Root for innermost index @i@; the outer indices are ignored.
        f :: (sh' :. Int) -> Complex
        f (_ :. i) = (cos theta, - sin theta)
          where theta = 2 * pi * fromIntegral i / len

        -- Transform length as a Double. An extent with @n == 0@ has no
        -- elements, so 'f' is never evaluated and the division cannot occur.
        len :: Double
        len = fromIntegral n


-- | Calculate the roots of unity used by the inverse transform.
--
--   For a result of extent @sh :. n@, the element at index @_ :. i@ is the
--   pair @(cos theta, sin theta)@ where @theta = 2 * pi * i / n@.
--   Read as @(real, imaginary)@, this is @exp (i * theta)@, the conjugate of
--   the corresponding forward root.
--   The value depends only on the innermost index, so every row of the
--   result holds the same @n@ roots.
calcInverseRootsOfUnityP
        :: (Shape sh, Monad m)
        => (sh :. Int)                  -- ^ Extent of the result; the innermost
                                        --   dimension @n@ is the transform length.
        -> m (Array U (sh :. Int) Complex)

calcInverseRootsOfUnityP sh@(_ :. n)
 = computeP $ fromFunction sh f
 where
        -- Root for innermost index @i@; the outer indices are ignored.
        f :: (sh' :. Int) -> Complex
        f (_ :. i) = (cos theta, sin theta)
          where theta = 2 * pi * fromIntegral i / len

        -- Transform length as a Double. An extent with @n == 0@ has no
        -- elements, so 'f' is never evaluated and the division cannot occur.
        len :: Double
        len = fromIntegral n