{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.Primitive.Contiguous.FFT
( dft
, idft
, overlapDFT
) where
import qualified Prelude
import Data.Eq (Eq((==)))
import Data.Function (($))
import Control.Monad
import Data.Ord
import Control.Monad.ST
import Data.Complex hiding (cis)
import qualified Data.Complex as C
import Data.Primitive.Contiguous
import GHC.Num (Num(..))
import GHC.Float
import GHC.Real
import GHC.Exts (Int)
-- | @cis k n@ is the unit complex number at angle @2*pi*k/n@ radians,
-- i.e. @exp (2*pi*i*k/n)@.
cis :: Floating a => a -> a -> Complex a
cis k n = C.cis angle
  where
    angle = 2 * pi * k / n
{-# INLINE cis #-}
-- | Build a complex number from its real and imaginary parts.
-- Both parts are forced to weak head normal form before the
-- constructor is applied.
mkComplex :: x -> x -> Complex x
mkComplex re im = re `Prelude.seq` im `Prelude.seq` (re :+ im)
{-# INLINE mkComplex #-}
-- | In-place DFT of a mutable array of complex samples.
--
-- For length @n@, let @rot k = (k + n \`div\` 2) \`mod\` n@. Bin @i@ is
--
-- @sum [ src(rot j) * exp (-2*pi*i*j/n) | j <- [0 .. n-1] ]@
--
-- where @src@ is the array's contents on entry. The bin is written to
-- position @rot i@. The result overwrites @mut@, which is also returned.
-- Direct O(n^2) evaluation.
dftMutable :: forall arr x s. (RealFloat x, Contiguous arr, Element arr (Complex x))
  => Mutable arr s (Complex x)
  -> ST s (Mutable arr s (Complex x))
dftMutable !mut = do
  !sz <- sizeMutable mut
  -- Snapshot the input first. Every output bin needs every input sample,
  -- so results cannot be written into mut while it is still being read.
  !src <- new sz :: ST s (Mutable arr s (Complex x))
  let copyIn :: Int -> ST s ()
      copyIn !k = when (k < sz) $ do
        v <- read mut k
        write src k v
        copyIn (k + 1)
  copyIn 0
  let getII !ix = (ix + sz `Prelude.div` 2) `Prelude.mod` sz
      !sCount = fromIntegral sz :: x
      go :: Int -> Int -> Complex x -> ST s ()
      go !i !j !acc
        | i == sz = return ()
        | j < sz = do
            atJJ <- read src (getII j)
            -- The twiddle factor is periodic in i*j with period sz. Reducing
            -- first keeps the angle small, so the phase stays accurate.
            let !theta = (-2) * pi * fromIntegral ((i * j) `Prelude.mod` sz) / sCount
                -- Full complex product: both parts of the sample contribute.
                !val = acc + atJJ * mkComplex (cos theta) (sin theta)
            go i (j + 1) val
        | Prelude.otherwise = do
            write mut (getII i) acc
            go (i + 1) 0 0
  go 0 0 0
  return mut
-- | Discrete Fourier transform of a real-valued array.
--
-- A pure wrapper that runs 'dftInternal' inside 'runST'. The argument is
-- forced to weak head normal form first.
dft :: forall arr x. (RealFloat x, Contiguous arr, Element arr x, Element arr (Complex x))
  => arr x
  -> arr (Complex x)
dft !a = runST transform
  where
    transform :: forall s. ST s (arr (Complex x))
    transform = dftInternal a
-- | DFT of a real-valued array, computed in 'ST'.
--
-- For input length @n@, let @rot k = (k + n \`div\` 2) \`mod\` n@. The
-- output has length @n@, and position @rot i@ holds
--
-- @sum [ a(rot j) * exp (-2*pi*i*j/n) | j <- [0 .. n-1] ]@.
--
-- Direct O(n^2) evaluation.
dftInternal :: forall arr x s. (RealFloat x, Contiguous arr, Element arr x, Element arr (Complex x))
  => arr x
  -> ST s (arr (Complex x))
dftInternal !a = do
  let !sz = size a
      !sCount = fromIntegral sz :: x
      getII !ix = (ix + sz `Prelude.div` 2) `Prelude.mod` sz
  !mut <- new sz :: ST s (Mutable arr s (Complex x))
  let go :: Int -> Int -> Complex x -> ST s ()
      go !i !j !acc
        | i == sz = return ()
        | j < sz = do
            let !atJJ = index a (getII j)
                -- The twiddle factor is periodic in i*j with period sz. Reducing
                -- first keeps the angle within one turn, which preserves phase
                -- accuracy for large sz (markedly so when x is Float).
                !theta = (-2) * pi * fromIntegral ((i * j) `Prelude.mod` sz) / sCount
                !val = acc + mkComplex (atJJ * cos theta) (atJJ * sin theta)
            go i (j + 1) val
        | Prelude.otherwise = do
            write mut (getII i) acc
            go (i + 1) 0 0
  go 0 0 0
  unsafeFreeze mut
-- | Real-valued inverse transform of a complex array.
--
-- A pure wrapper that runs 'idftInternal' inside 'runST'. The argument is
-- forced to weak head normal form first.
idft :: forall arr x. (RealFloat x, Contiguous arr, Element arr x, Element arr (Complex x))
  => arr (Complex x)
  -> arr x
idft !a = runST transform
  where
    transform :: forall s. ST s (arr x)
    transform = idftInternal a
-- | Inverse DFT of a complex array, keeping only the real part, in 'ST'.
--
-- For input length @n@, let @rot k = (k + n \`div\` 2) \`mod\` n@. The
-- output has length @n@, and position @rot i@ holds
--
-- @(1/n) * sum [ realPart (a(rot j) * exp (2*pi*i*j/n)) | j <- [0 .. n-1] ]@.
--
-- This uses the same index rotation as 'dftInternal'. Direct O(n^2)
-- evaluation.
idftInternal :: forall arr x s. (RealFloat x, Contiguous arr, Element arr x, Element arr (Complex x))
  => arr (Complex x)
  -> ST s (arr x)
idftInternal !a = do
  let !sz = size a
      !sCount = fromIntegral sz :: x
      getII !ix = (ix + sz `Prelude.div` 2) `Prelude.mod` sz
  !mut <- new sz :: ST s (Mutable arr s x)
  let go :: Int -> Int -> x -> ST s ()
      go !i !j !acc
        | i == sz = return ()
        | j < sz = do
            let !(real :+ imag) = index a (getII j)
                -- Reduce i*j modulo sz (the twiddle's period) to keep the
                -- angle small and the phase accurate.
                !theta = (-2) * pi * fromIntegral ((i * j) `Prelude.mod` sz) / sCount
                -- theta is the negated angle, so this is
                -- Re ((real :+ imag) * exp (+i*|theta|)).
                -- The term is added to the running sum; the original
                -- replaced acc and kept only the last term.
                !val = acc + real * cos theta + imag * sin theta
            go i (j + 1) val
        | Prelude.otherwise = do
            -- Apply the 1/n normalisation once per output bin.
            write mut (getII i) (acc / sCount)
            go (i + 1) 0 0
  go 0 0 0
  unsafeFreeze mut
-- | Sliding-DFT update. Converts the DFT of one window into the DFT of the
-- window advanced by one sample, in a single O(size f1) pass.
--
-- Suppose @f1@ holds @F1[k] = sum [ x1[m] * exp (-2*pi*i*k*m/n) | m <- [0 .. n-1] ]@
-- and the next window is @x2[m] = x1[m+1]@. Then @x1[0]@ leaves the window,
-- @x2_N_1 = x2[n-1]@ enters it, and
--
-- @F2[k] = exp (2*pi*i*k/n) * (F1[k] - x1[0] + x2[n-1])@.
--
-- Every element of @f1@ is overwritten with @F2@, and @f1@ is returned.
-- @n@ is the window length used for the twiddle factors.
--
-- The recurrence assumes standard, unrotated DFT indexing.
-- NOTE(review): @x1@ must be non-empty, because its element 0 is read.
overlapDFT :: forall arr x s. (RealFloat x, Contiguous arr, Element arr x, Element arr (Complex x))
  => Int
  -> Mutable arr s (Complex x)
  -> Complex x
  -> Mutable arr s (Complex x)
  -> ST s (Mutable arr s (Complex x))
overlapDFT n x1 x2_N_1 f1 = do
  let !sz = fromIntegral n :: x
  !l <- sizeMutable f1
  !x1_0 <- read x1 0 :: ST s (Complex x)
  let go :: Int -> ST s ()
      go !ix = when (ix < l) $ do
        f1_k <- read f1 ix
        let !twiddle = cis (fromIntegral ix) sz
            -- Remove the outgoing sample and add the incoming one.
            -- The original added x1_0 instead of subtracting it.
            -- The bang stops repeated updates from chaining thunks in
            -- boxed arrays.
            !fin = twiddle * (f1_k - x1_0 + x2_N_1)
        write f1 ix fin
        go (ix + 1)
  go 0
  return f1