{-# LANGUAGE BangPatterns, FlexibleContexts #-}
module Statistics.Transform
(
CD
, dct
, dct_
, idct
, idct_
, fft
, ifft
) where
import Control.Monad (when)
import Control.Monad.ST (ST)
import Data.Bits (shiftL, shiftR)
import Data.Complex (Complex(..), conjugate, realPart)
import Numeric.SpecFunctions (log2)
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as M
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector as V
-- | Double-precision complex number: the element type used by every
-- transform in this module.
type CD = Complex Double
-- | Discrete cosine transform of a real vector, computed via 'fft'.
--
-- For an input @x@ of length @N@, element @k@ of the result is
--
-- > 2 * sum [ x_j * cos (pi * k * (2*j + 1) / (2*N)) | j <- [0 .. N-1] ]
--
-- i.e. an unnormalised DCT-II scaled by 2. The length must be a power of
-- two; any other length raises an error.
dct :: (G.Vector v CD, G.Vector v Double, G.Vector v Int) => v Double -> v Double
dct xs = dctWorker (G.map toComplex xs)
  where
    toComplex r = r :+ 0
{-# INLINABLE dct #-}
{-# SPECIALIZE dct :: U.Vector Double -> U.Vector Double #-}
{-# SPECIALIZE dct :: V.Vector Double -> V.Vector Double #-}
-- | Like 'dct', but for a complex input vector. Only the real part of each
-- element is transformed; the imaginary part is discarded.
dct_ :: (G.Vector v CD, G.Vector v Double, G.Vector v Int) => v CD -> v Double
dct_ xs = dctWorker (G.map dropImaginary xs)
  where
    dropImaginary (re :+ _) = re :+ 0
{-# INLINABLE dct_ #-}
{-# SPECIALIZE dct_ :: U.Vector CD -> U.Vector Double #-}
{-# SPECIALIZE dct_ :: V.Vector CD -> V.Vector Double #-}
-- | Shared implementation of 'dct' and 'dct_'. Only the real parts of the
-- input are used; both callers zero the imaginary parts.
--
-- Reorders the input as its even-indexed elements followed by its
-- odd-indexed elements in reverse. It then takes the 'fft' of that,
-- multiplies by the weights @2 * exp (-i * pi * k / (2*N))@ and keeps the
-- real part.
dctWorker :: (G.Vector v CD, G.Vector v Double, G.Vector v Int) => v CD -> v Double
{-# INLINE dctWorker #-}
dctWorker xs
  -- For length 1 the index ranges below are both empty, so the reordering
  -- would lose the element. Handle it directly: the transform of [x] is [2*x].
  | len == 1    = G.map (\z -> 2 * realPart z) xs
  | vectorOK xs = G.map realPart $ G.zipWith (*) weights (fft reordered)
  | otherwise   = error "Statistics.Transform.dct: bad vector length"
  where
    len = G.length xs
    n   = fi len
    -- Indices 0, 2, 4, ..., N-2 followed by N-1, N-3, ..., 1.
    evens        = G.enumFromThenTo 0 2 (len - 2)
    oddsReversed = G.enumFromThenTo (len - 1) (len - 3) 1
    reordered    = G.backpermute xs (evens G.++ oddsReversed)
    -- Weight 2 for k = 0, and 2 * exp(-i*pi*k/(2N)) for k = 1 .. N-1.
    weights = G.cons 2 $ G.generate (len - 1) $ \k ->
      2 * exp ((0 :+ (-1)) * fi (k + 1) * pi / (2 * n))
-- | Inverse discrete cosine transform of a real vector, computed via 'ifft'.
--
-- For an input @y@ of length @N@, element @j@ of the result is
--
-- > y_0 + 2 * sum [ y_k * cos (pi * k * (2*j + 1) / (2*N)) | k <- [1 .. N-1] ]
--
-- i.e. an unnormalised DCT-III scaled by 2. It inverts 'dct' only up to
-- scale: @idct (dct x)@ is @x@ multiplied by @2 * length x@, up to rounding.
-- The length must be a power of two; any other length raises an error.
idct :: (G.Vector v CD, G.Vector v Double) => v Double -> v Double
idct ys = idctWorker (G.map toComplex ys)
  where
    toComplex r = r :+ 0
{-# INLINABLE idct #-}
{-# SPECIALIZE idct :: U.Vector Double -> U.Vector Double #-}
{-# SPECIALIZE idct :: V.Vector Double -> V.Vector Double #-}
-- | Like 'idct', but for a complex input vector. Only the real part of each
-- element is transformed; the imaginary part is discarded.
idct_ :: (G.Vector v CD, G.Vector v Double) => v CD -> v Double
idct_ ys = idctWorker (G.map dropImaginary ys)
  where
    dropImaginary (re :+ _) = re :+ 0
{-# INLINABLE idct_ #-}
{-# SPECIALIZE idct_ :: U.Vector CD -> U.Vector Double #-}
{-# SPECIALIZE idct_ :: V.Vector CD -> V.Vector Double #-}
-- | Shared implementation of 'idct' and 'idct_'. Only the real parts of the
-- input are used; both callers zero the imaginary parts.
--
-- Multiplies the input by twiddle weights and takes the real part of its
-- 'ifft'. It then undoes the even/odd reordering that 'dctWorker' applies
-- before its 'fft'.
idctWorker :: (G.Vector v CD, G.Vector v Double) => v CD -> v Double
{-# INLINE idctWorker #-}
idctWorker xs
  | vectorOK xs = G.generate len interleave
  -- The message names idct. It previously said "dct", which pointed callers
  -- of 'idct' / 'idct_' at the wrong function.
  | otherwise   = error "Statistics.Transform.idct: bad vector length"
  where
    len = G.length xs
    -- Output index z reads element z/2 of 'vals' when z is even, and element
    -- len - 1 - z/2 (integer halving) when z is odd. For 0 <= z < len both
    -- indices lie in [0, len), so unsafeIndex is in bounds.
    interleave z
      | even z    = vals `G.unsafeIndex` halve z
      | otherwise = vals `G.unsafeIndex` (len - halve z - 1)
    vals = G.map realPart . ifft $ G.zipWith (*) weights xs
    -- Weight N for k = 0, and 2 * N * exp(i*pi*k/(2N)) for k = 1 .. N-1.
    weights = G.cons n $ G.generate (len - 1) $ \k ->
      2 * n * exp ((0 :+ 1) * fi (k + 1) * pi / (2 * n))
      where
        n = fi len
-- | Inverse fast Fourier transform. Conjugates the input, applies 'fft',
-- conjugates again and divides every element by the length. The length
-- must be a power of two; any other length raises an error.
ifft :: G.Vector v CD => v CD -> v CD
ifft xs
  | not (vectorOK xs) = error "Statistics.Transform.ifft: bad vector length"
  | otherwise         = G.map (rescale . conjugate) (fft (G.map conjugate xs))
  where
    rescale z = z / fi (G.length xs)
{-# INLINABLE ifft #-}
{-# SPECIALIZE ifft :: U.Vector CD -> U.Vector CD #-}
{-# SPECIALIZE ifft :: V.Vector CD -> V.Vector CD #-}
-- | Fast Fourier transform. The radix-2 work is done by 'mfft' on a mutable
-- vector obtained via 'G.modify', so the caller's input vector is not
-- changed. The length must be a power of two; any other length raises an
-- error.
fft :: G.Vector v CD => v CD -> v CD
fft v
  | vectorOK v = G.modify mfft v
  | otherwise  = error "Statistics.Transform.fft: bad vector length"
{-# INLINABLE fft #-}
{-# SPECIALIZE fft :: U.Vector CD -> U.Vector CD #-}
{-# SPECIALIZE fft :: V.Vector CD -> V.Vector CD #-}
-- | In-place radix-2 FFT of a mutable vector (iterative Cooley-Tukey).
--
-- First permutes the elements into bit-reversed index order. It then runs
-- @log2 n@ butterfly stages; stage @l@ combines sub-transforms of size @l1@
-- into size @l2 = 2*l1@ using twiddle factors @exp (-2*pi*i*j/l2)@.
-- Only power-of-two lengths are supported ('fft' checks this before calling).
-- Vectors of length 0 or 1 are left unchanged.
mfft :: (M.MVector v CD) => v s CD -> ST s ()
{-# INLINE mfft #-}
mfft vec
  -- With len == 0 the bit-reversal carry loop below would never terminate
  -- (k = l = 0 forever), and 'log2' has no meaningful value at 0. A
  -- length-1 vector is already its own transform.
  | len < 2   = return ()
  | otherwise = bitReverse 0 0
  where
    len = M.length vec
    m   = log2 len

    -- Bit-reversal permutation: j is the bit-reversed counterpart of i.
    -- Each pair is swapped exactly once, when i < j.
    bitReverse i j
      | i == len - 1 = stage 0 1
      | otherwise    = do
          when (i < j) $ M.swap vec i j
          -- Increment j in bit-reversed order (carry from the top bit down).
          let inner k l
                | k <= l    = inner (k `shiftR` 1) (l - k)
                | otherwise = bitReverse (i + 1) (l + k)
          inner (len `shiftR` 1) j

    -- Stage l merges pairs of length-l1 sub-transforms into length-l2 ones.
    stage l !l1
      | l == m    = return ()
      | otherwise = do
          let !l2 = l1 `shiftL` 1
              !e  = -6.283185307179586 / fromIntegral l2
              -- flight j performs the j-th butterfly of every block, with the
              -- accumulated twiddle angle a (== j * e up to rounding).
              flight j !a
                | j == l1   = stage (l + 1) l2
                | otherwise = do
                    let butterfly i
                          | i >= len  = flight (j + 1) (a + e)
                          | otherwise = do
                              let i1 = i + l1
                              xi1 :+ yi1 <- M.read vec i1
                              let !c = cos a
                                  !s = sin a
                                  -- d = exp(i*a) * vec[i1]
                                  d  = (c * xi1 - s * yi1) :+ (s * xi1 + c * yi1)
                              ci <- M.read vec i
                              M.write vec i1 (ci - d)
                              M.write vec i  (ci + d)
                              butterfly (i + l2)
                    butterfly j
          flight 0 0
-- | Convert an 'Int' (here always a vector length or index) to a complex
-- number with zero imaginary part.
fi :: Int -> CD
fi n = fromIntegral n
-- | Integer halving by an arithmetic right shift of one bit. For negative
-- inputs this rounds towards negative infinity.
halve :: Int -> Int
halve n = n `shiftR` 1
-- | True when the vector's length is a positive power of two. These are the
-- only lengths the transforms in this module accept.
vectorOK :: G.Vector v a => v a -> Bool
{-# INLINE vectorOK #-}
vectorOK v = n > 0 && (1 `shiftL` log2 n) == n
  where
    -- n > 0 is tested first, so '&&' short-circuits and 'log2' is never
    -- applied to 0. For 0 it has no meaningful value and may raise its own
    -- error. An empty vector is now reported as False, so callers raise
    -- their "bad vector length" error instead.
    n = G.length v