{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Data.Array.Accelerate.Data.Complex (
Complex(..), pattern (::+),
real,
imag,
mkPolar,
cis,
polar,
magnitude, magnitude',
phase,
conjugate,
) where
import Data.Array.Accelerate.Classes
import Data.Array.Accelerate.Data.Functor
import Data.Array.Accelerate.Pattern
import Data.Array.Accelerate.Prelude
import Data.Array.Accelerate.Representation.Tag
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Representation.Vec
import Data.Array.Accelerate.Smart
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.Sugar.Vec
import Data.Array.Accelerate.Type
import Data.Primitive.Vec
import Data.Complex ( Complex(..) )
import Prelude ( ($) )
import qualified Data.Complex as C
import qualified Prelude as P
infix 6 ::+

-- | Bidirectional pattern for complex-number expressions: build or take
-- apart an @'Exp' ('Complex' a)@ from its real and imaginary parts.
--
-- Matching is delegated to 'deconstructComplex' and construction to
-- 'constructComplex', so the pattern works for whichever representation
-- 'complexR' selects for the component type.
pattern (::+) :: Elt a => Exp a -> Exp a -> Exp (Complex a)
pattern re ::+ im <- (deconstructComplex -> (re, im))
  where
    re ::+ im = constructComplex re im
{-# COMPLETE (::+) #-}
-- | Scalar element instance for complex numbers.
--
-- The representation is chosen by 'complexR' from the component's
-- representation: either a two-lane vector of the component's single
-- type ('ComplexVec'), or the generic unit-terminated nested tuple
-- @(((), re), im)@ ('ComplexTup').
instance Elt a => Elt (Complex a) where
  type EltR (Complex a) = ComplexR (EltR a)

  -- Representation type: a length-2 'VectorType' over the lane type, or a
  -- pair of two copies of the component representation.
  eltR = let tR = eltR @a
          in case complexR tR of
               ComplexVec s -> TupRsingle $ VectorScalarType $ VectorType 2 s
               ComplexTup   -> TupRunit `TupRpair` tR `TupRpair` tR

  -- Tag representations. The vector case has a single tag for the whole
  -- vector. In the tuple case, every combination of the component tags is
  -- listed ('go' enumerates a component's tags; the comprehension below
  -- takes the cartesian product for the real and imaginary fields).
  tagsR = let tR = eltR @a
           in case complexR tR of
                ComplexVec s -> [ TagRsingle (VectorScalarType (VectorType 2 s)) ]
                ComplexTup   -> let go :: TypeR t -> [TagR t]
                                    go TupRunit         = [TagRunit]
                                    go (TupRsingle s)   = [TagRsingle s]
                                    go (TupRpair ta tb) = [TagRpair a b | a <- go ta, b <- go tb]
                                 in
                                 [ TagRunit `TagRpair` ta `TagRpair` tb | ta <- go tR, tb <- go tR ]

  -- The lambdas live inside the case alternatives so that each one is
  -- checked under the type refinement produced by matching 'complexR'.
  toElt = case complexR $ eltR @a of
    ComplexVec _ -> \(Vec2 r i)   -> toElt r :+ toElt i
    ComplexTup   -> \(((), r), i) -> toElt r :+ toElt i

  -- Inverse of 'toElt': pack the two converted components into the
  -- representation selected by 'complexR'.
  fromElt (r :+ i) = case complexR $ eltR @a of
    ComplexVec _ -> Vec2 (fromElt r) (fromElt i)
    ComplexTup   -> (((), fromElt r), fromElt i)
-- | Representation of @'Complex' a@ given the representation @a@ of its
-- components.
--
-- Each primitive integral and floating-point type listed here is packed
-- into a two-lane 'Vec2'. Any other representation falls through to the
-- final equation and is stored as the nested tuple @(((), re), im)@.
-- Because this is a closed family, the catch-all equation applies only to
-- types apart from all of the preceding ones. 'complexR' recovers at the
-- term level which equation applied.
type family ComplexR a where
  ComplexR Half   = Vec2 Half
  ComplexR Float  = Vec2 Float
  ComplexR Double = Vec2 Double
  ComplexR Int    = Vec2 Int
  ComplexR Int8   = Vec2 Int8
  ComplexR Int16  = Vec2 Int16
  ComplexR Int32  = Vec2 Int32
  ComplexR Int64  = Vec2 Int64
  ComplexR Word   = Vec2 Word
  ComplexR Word8  = Vec2 Word8
  ComplexR Word16 = Vec2 Word16
  ComplexR Word32 = Vec2 Word32
  ComplexR Word64 = Vec2 Word64
  -- Fallback: generic nested-tuple representation.
  ComplexR a      = (((), a), a)
-- | Term-level witness of which representation 'ComplexR' chose for a
-- component representation @a@. The index @c@ is that representation.
data ComplexType a c where
  -- Packed two-lane vector. Carries the lane's 'SingleType' and a 'VecElt'
  -- dictionary for it.
  ComplexVec :: VecElt a => SingleType a -> ComplexType a (Vec2 a)
  -- Generic nested-tuple representation.
  ComplexTup :: ComplexType a (((), a), a)
-- | Inspect a component representation and report how 'ComplexR' lays out
-- a complex number over it.
--
-- Every primitive integral and floating-point scalar type yields
-- 'ComplexVec' together with its 'SingleType' evidence. Units, pairs and
-- vector scalar types yield 'ComplexTup'. The case structure mirrors the
-- equations of 'ComplexR'.
complexR :: TypeR a -> ComplexType a (ComplexR a)
complexR tR =
  case tR of
    TupRsingle st -> onScalar st
    TupRunit      -> ComplexTup
    TupRpair{}    -> ComplexTup
  where
    onScalar :: ScalarType t -> ComplexType t (ComplexR t)
    onScalar st =
      case st of
        SingleScalarType sg -> onSingle sg
        VectorScalarType{}  -> ComplexTup

    onSingle :: SingleType t -> ComplexType t (ComplexR t)
    onSingle sg =
      case sg of
        NumSingleType nt -> onNum nt

    onNum :: NumType t -> ComplexType t (ComplexR t)
    onNum nt =
      case nt of
        IntegralNumType it -> onIntegral it
        FloatingNumType ft -> onFloating ft

    -- Each alternative refines the type index, which is what lets the
    -- overloaded 'singleType' resolve to the right instance.
    onIntegral :: IntegralType t -> ComplexType t (ComplexR t)
    onIntegral it =
      case it of
        TypeInt    -> ComplexVec singleType
        TypeInt8   -> ComplexVec singleType
        TypeInt16  -> ComplexVec singleType
        TypeInt32  -> ComplexVec singleType
        TypeInt64  -> ComplexVec singleType
        TypeWord   -> ComplexVec singleType
        TypeWord8  -> ComplexVec singleType
        TypeWord16 -> ComplexVec singleType
        TypeWord32 -> ComplexVec singleType
        TypeWord64 -> ComplexVec singleType

    onFloating :: FloatingType t -> ComplexType t (ComplexR t)
    onFloating ft =
      case ft of
        TypeHalf   -> ComplexVec singleType
        TypeFloat  -> ComplexVec singleType
        TypeDouble -> ComplexVec singleType
-- | Builder half of '::+': assemble a complex expression from its real and
-- imaginary parts, in whichever representation 'complexR' selects.
constructComplex :: forall a. Elt a => Exp a -> Exp a -> Exp (Complex a)
constructComplex re im =
  case complexR (eltR @a) of
    -- Packed form: view each part at its representation type and pack the
    -- two lanes into a vector.
    ComplexVec _ ->
      V2 (coerce @a @(EltR a) re) (coerce @a @(EltR a) im)
    -- Tuple form: same representation as a plain pair of the parts.
    ComplexTup ->
      coerce (T2 re im)
-- | Matcher half of '::+': split a complex expression into its real and
-- imaginary parts.
deconstructComplex :: forall a. Elt a => Exp (Complex a) -> (Exp a, Exp a)
deconstructComplex z@(Exp z') =
  case complexR (eltR @a) of
    -- Packed form: unpack the two-lane vector into a pair of lanes.
    ComplexVec st ->
      let T2 re im = Exp (SmartExp (VecUnpack (VecRsucc (VecRsucc (VecRnil st))) z'))
       in (re, im)
    -- Tuple form: reinterpret as a pair and project both fields.
    ComplexTup ->
      let T2 re im = coerce z
       in (re, im)
-- | Reinterpret an expression at another surface type with the same
-- representation. Only the wrapped 'SmartExp' is rewrapped; no new
-- expression node is produced.
coerce :: EltR a ~ EltR b => Exp a -> Exp b
coerce = \(Exp smart) -> Exp smart
-- | Lift a host-side complex number whose parts are liftable into a
-- complex expression, part by part.
instance (Lift Exp a, Elt (Plain a)) => Lift Exp (Complex a) where
  type Plain (Complex a) = Complex (Plain a)
  lift z = case z of
    re :+ im -> lift re ::+ lift im
-- | Split a complex expression into a host-side 'Complex' of part
-- expressions.
instance Elt a => Unlift Exp (Complex (Exp a)) where
  unlift z = case z of
    re ::+ im -> re :+ im
-- | Component-wise equality: two complex numbers are equal when both their
-- real and imaginary parts are equal, and unequal when either part differs.
instance Eq a => Eq (Complex a) where
  (==) (xr ::+ xi) (yr ::+ yi) = xr == yr && xi == yi
  (/=) (xr ::+ xi) (yr ::+ yi) = xr /= yr || xi /= yi
-- | Complex arithmetic on expressions.
--
-- The ring operations reuse the 'Data.Complex' instance at type
-- @Complex (Exp a)@ and are lifted through 'lift1' / 'lift2'. 'signum' and
-- 'abs' are written directly in terms of 'magnitude'.
instance RealFloat a => P.Num (Exp (Complex a)) where
  (+)    = lift2 ((+) @(Complex (Exp a)))
  (-)    = lift2 ((-) @(Complex (Exp a)))
  (*)    = lift2 ((*) @(Complex (Exp a)))
  negate = lift1 (negate @(Complex (Exp a)))

  -- Unit vector in the direction of z; zero is mapped to itself.
  signum z@(x ::+ y) =
    if z == 0
      then z
      else x / m ::+ y / m
    where
      m = magnitude z

  abs z         = magnitude z ::+ 0
  fromInteger n = fromInteger n ::+ 0
-- | Complex division on expressions.
instance RealFloat a => P.Fractional (Exp (Complex a)) where
  fromRational r = fromRational r ::+ 0

  -- The divisor's parts are rescaled by 2^k, with k the negated larger of
  -- their binary exponents, before forming the denominator. This keeps the
  -- intermediate products in range.
  z / w = (zr*wr' + zi*wi') / den ::+ (zi*wr' - zr*wi') / den
    where
      zr ::+ zi = z
      wr ::+ wi = w
      k         = - max (exponent wr) (exponent wi)
      wr'       = scaleFloat k wr
      wi'       = scaleFloat k wi
      den       = wr*wr' + wi*wi'
-- | Elementary functions on complex expressions.
--
-- The formulae are those of the 'P.Floating' instance for 'Complex' in
-- base's "Data.Complex", written against 'Exp'. With RebindableSyntax,
-- @if@, numeric literals and unary minus resolve to the versions in scope
-- here. Where a formula selects a branch of a multivalued function, the
-- same branch is used.
instance RealFloat a => P.Floating (Exp (Complex a)) where
  pi = pi ::+ 0

  exp (x ::+ y) = let expx = exp x
                   in expx * cos y ::+ expx * sin y

  -- log |z| + i * arg z
  log z = log (magnitude z) ::+ phase z

  -- Both candidate parts are derived from |z| + |x|, a sum of non-negative
  -- terms. Which one becomes the real part depends on the sign of x, and
  -- the imaginary part takes the sign of y.
  sqrt z@(x ::+ y) =
    if z == 0
      then 0
      else u ::+ (y < 0 ? (-v, v))
    where
      T2 u v = x < 0 ? (T2 v' u', T2 u' v')
      v'     = abs y / (u'*2)
      u'     = sqrt ((magnitude z + abs x) / 2)

  -- Special cases on the real part of the exponent (exp_r):
  --   z ** 0  = 1
  --   0 ** y  = 0, infinity or NaN for exp_r > 0, < 0, == 0
  --   an infinite base gives infinity, 0 or NaN for exp_r > 0, < 0, == 0
  -- Otherwise exp (log x * y).
  x ** y =
    if y == 0 then 1 else
    if x == 0 then if exp_r > 0 then 0 else
                   if exp_r < 0 then inf ::+ 0
                                else nan ::+ nan
              else if isInfinite r || isInfinite i
                     then if exp_r > 0 then inf ::+ 0 else
                          if exp_r < 0 then 0
                                       else nan ::+ nan
                     else exp (log x * y)
    where
      r     ::+ i = x
      exp_r ::+ _ = y
      inf = 1 / 0
      nan = 0 / 0

  sin (x ::+ y) = sin x * cosh y ::+ cos x * sinh y
  cos (x ::+ y) = cos x * cosh y ::+ (- sin x * sinh y)
  tan (x ::+ y) = (sinx*coshy ::+ cosx*sinhy) / (cosx*coshy ::+ (-sinx*sinhy))
    where
      sinx  = sin x
      cosx  = cos x
      sinhy = sinh y
      coshy = cosh y

  sinh (x ::+ y) = cos y * sinh x ::+ sin y * cosh x
  cosh (x ::+ y) = cos y * cosh x ::+ sin y * sinh x
  tanh (x ::+ y) = (cosy*sinhx ::+ siny*coshx) / (cosy*coshx ::+ siny*sinhx)
    where
      siny  = sin y
      cosy  = cos y
      sinhx = sinh x
      coshx = cosh x

  asin z@(x ::+ y) = y' ::+ (-x')
    where
      x' ::+ y' = log (((-y) ::+ x) + sqrt (1 - z*z))

  acos z = y'' ::+ (-x'')
    where
      x'' ::+ y'' = log (z + ((-y') ::+ x'))
      x'  ::+ y'  = sqrt (1 - z*z)

  atan z@(x ::+ y) = y' ::+ (-x')
    where
      x' ::+ y' = log (((1-y) ::+ x) / sqrt (1+z*z))

  asinh z = log (z + sqrt (1+z*z))

  -- Factored as sqrt (z+1) * sqrt (z-1). The former expression
  -- (z+1) * sqrt ((z-1)/(z+1)) divides 0 by 0 at z = -1, so it returned
  -- NaN instead of the correct value 0 :+ pi (same fix as base, GHC #8532).
  acosh z = log (z + sqrt (z+1) * sqrt (z-1))

  atanh z = 0.5 * log ((1.0+z) / (1.0-z))
-- | Convert an integral expression to a complex number with zero imaginary
-- part.
instance (FromIntegral a b, Num b, Elt (Complex b)) => FromIntegral a (Complex b) where
  fromIntegral n = re ::+ 0
    where
      re = fromIntegral n
-- | Apply a function to both the real and the imaginary part.
instance Functor Complex where
  fmap f z = case z of
    re ::+ im -> f re ::+ f im
-- | Absolute value (modulus) of a complex number.
--
-- Both parts are scaled by 2^-e, where e is the larger of their binary
-- exponents. The root of the sum of squares is then scaled back by 2^e,
-- which keeps the squares in range.
magnitude :: RealFloat a => Exp (Complex a) -> Exp a
magnitude (re ::+ im) = scaleFloat e (sqrt (re' * re' + im' * im'))
  where
    e   = max (exponent re) (exponent im)
    ne  = - e
    re' = scaleFloat ne re
    im' = scaleFloat ne im
-- | Modulus computed directly as sqrt (re^2 + im^2), without the
-- rescaling that 'magnitude' performs.
magnitude' :: RealFloat a => Exp (Complex a) -> Exp a
magnitude' (re ::+ im) = sqrt (sq re + sq im)
  where
    sq v = v * v
-- | Argument of a complex number, computed with 'atan2'. Zero is mapped
-- to 0.
phase :: RealFloat a => Exp (Complex a) -> Exp a
phase z =
  case z of
    re ::+ im -> if z == 0 then 0 else atan2 im re
-- | Polar form of a complex number as the pair (magnitude, phase).
polar :: RealFloat a => Exp (Complex a) -> Exp (a,a)
polar z = T2 r theta
  where
    r     = magnitude z
    theta = phase z
-- | Build a complex number from a magnitude and a phase. This lifts
-- 'C.mkPolar' from "Data.Complex".
mkPolar :: forall a. Floating a => Exp a -> Exp a -> Exp (Complex a)
mkPolar = lift2 (C.mkPolar @(Exp a))
-- | Unit-magnitude complex number with the given phase. This lifts
-- 'C.cis' from "Data.Complex".
cis :: forall a. Floating a => Exp a -> Exp (Complex a)
cis = lift1 (C.cis @(Exp a))
-- | Real part of a complex expression.
real :: Elt a => Exp (Complex a) -> Exp a
real z = case z of
  re ::+ _ -> re
-- | Imaginary part of a complex expression.
imag :: Elt a => Exp (Complex a) -> Exp a
imag z = case z of
  _ ::+ im -> im
-- | Complex conjugate: the same real part, with the imaginary part negated.
conjugate :: Num a => Exp (Complex a) -> Exp (Complex a)
conjugate z = re ::+ (- im)
  where
    re = real z
    im = imag z