{-# LANGUAGE ConstraintKinds       #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE MagicHash             #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms       #-}
{-# LANGUAGE RebindableSyntax      #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeSynonymInstances  #-}
{-# LANGUAGE UndecidableInstances  #-}
{-# LANGUAGE ViewPatterns          #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
-- |
-- Module      : Data.Array.Accelerate.Data.Complex
-- Copyright   : [2015..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--
-- Complex numbers, stored in the usual C-style array-of-struct representation,
-- for easy interoperability.
--
module Data.Array.Accelerate.Data.Complex (

  -- * Rectangular form
  Complex(..), pattern (::+),
  real,
  imag,

  -- * Polar form
  mkPolar,
  cis,
  polar,
  magnitude, magnitude',
  phase,

  -- * Conjugate
  conjugate,

) where

import Data.Array.Accelerate.Classes
import Data.Array.Accelerate.Data.Functor
import Data.Array.Accelerate.Pattern
import Data.Array.Accelerate.Prelude
import Data.Array.Accelerate.Representation.Tag
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Representation.Vec
import Data.Array.Accelerate.Smart
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.Sugar.Vec
import Data.Array.Accelerate.Type
import Data.Primitive.Vec

import Data.Complex                                                 ( Complex(..) )
import Prelude                                                      ( ($) )
import qualified Data.Complex                                       as C
import qualified Prelude                                            as P


infix 6 ::+
-- | Pattern synonym for complex numbers embedded in 'Exp'.
--
-- As a pattern it splits an embedded complex number into its two components
-- via 'deconstructComplex'; as an expression it combines two components
-- via 'constructComplex'. The first field is the real part and the second
-- the imaginary part.
--
pattern (::+) :: Elt a => Exp a -> Exp a -> Exp (Complex a)
pattern re ::+ im <- (deconstructComplex -> (re, im))
  where
    re ::+ im = constructComplex re im
{-# COMPLETE (::+) #-}


-- Use an array-of-structs representation for complex numbers if possible.
-- This matches the standard C-style layout, but we can use this representation
-- only at specific types (not for any type 'a') as we can only have vectors of
-- primitive type. For other types, we use a structure-of-arrays
-- representation. The choice of representation type is made by the type family
-- 'ComplexR'. We use the GADT 'ComplexType' and the function 'complexR' to
-- recover, at the value level, which of the two representations is in use.
--
instance Elt a => Elt (Complex a) where
  -- The representation is computed from the representation of the component.
  type EltR (Complex a) = ComplexR (EltR a)

  -- Type witness: either a single 2-wide vector of the scalar type, or the
  -- nested pair ((), re), im built from the component's own witness.
  eltR = let tR = eltR @a
          in case complexR tR of
               ComplexVec s -> TupRsingle $ VectorScalarType $ VectorType 2 s
               ComplexTup   -> TupRunit `TupRpair` tR `TupRpair` tR

  -- Tags: a single tag for the vector representation; for the tuple
  -- representation, every combination of tags for the two components.
  tagsR = let tR = eltR @a
           in case complexR tR of
               ComplexVec s -> [ TagRsingle (VectorScalarType (VectorType 2 s)) ]
               -- 'go' enumerates the tags for a component type: one tag for
               -- unit and for a scalar, and all pairings for a pair.
               ComplexTup   -> let go :: TypeR t -> [TagR t]
                                   go TupRunit         = [TagRunit]
                                   go (TupRsingle s)   = [TagRsingle s]
                                   go (TupRpair ta tb) = [TagRpair a b | a <- go ta, b <- go tb]
                                in
                                [ TagRunit `TagRpair` ta `TagRpair` tb | ta <- go tR, tb <- go tR ]

  -- Convert from the representation back to 'Complex'. Matching on the
  -- result of 'complexR' tells the type checker which representation the
  -- lambda's argument has.
  toElt = case complexR $ eltR @a of
    ComplexVec _ -> \(Vec2 r i)   -> toElt r :+ toElt i
    ComplexTup   -> \(((), r), i) -> toElt r :+ toElt i

  -- Convert a 'Complex' value into whichever representation applies.
  fromElt (r :+ i) = case complexR $ eltR @a of
    ComplexVec _ -> Vec2 (fromElt r) (fromElt i)
    ComplexTup   -> (((), fromElt r), fromElt i)

-- | Representation type of a complex number, computed from the representation
-- type of its components. The primitive types listed here are stored as
-- a two-element vector; every other type falls through to the final equation
-- and is stored as a nested pair.
--
type family ComplexR a where
  -- signed integers
  ComplexR Int    = Vec2 Int
  ComplexR Int8   = Vec2 Int8
  ComplexR Int16  = Vec2 Int16
  ComplexR Int32  = Vec2 Int32
  ComplexR Int64  = Vec2 Int64
  -- unsigned integers
  ComplexR Word   = Vec2 Word
  ComplexR Word8  = Vec2 Word8
  ComplexR Word16 = Vec2 Word16
  ComplexR Word32 = Vec2 Word32
  ComplexR Word64 = Vec2 Word64
  -- floating point
  ComplexR Half   = Vec2 Half
  ComplexR Float  = Vec2 Float
  ComplexR Double = Vec2 Double
  -- anything else (must remain the last equation)
  ComplexR a      = (((), a), a)

-- Evidence for which representation a complex number with components of
-- (representation) type @e@ uses; the second parameter is that representation.
--
-- This isn't ideal because we gather the evidence based on the
-- representation type, so we really get the evidence (VecElt (EltR a)),
-- which is not very useful...
--    - TLM 2020-07-16
data ComplexType e repr where
  -- Components stored as a nested pair.
  ComplexTup ::                             ComplexType e (((), e), e)
  -- Components packed into a two-element vector of a single (scalar) type.
  ComplexVec :: VecElt e => SingleType e -> ComplexType e (Vec2 e)

-- | Determine, from the representation type of the components, how a complex
-- number over that type is represented. Every single numeric type yields
-- 'ComplexVec'; unit, pairs and vector types yield 'ComplexTup'.
--
-- Each primitive type is matched individually so that the type checker can
-- reduce 'ComplexR' at that concrete type.
--
complexR :: TypeR a -> ComplexType a (ComplexR a)
complexR tR =
  case tR of
    TupRunit                      -> ComplexTup
    TupRpair{}                    -> ComplexTup
    TupRsingle VectorScalarType{} -> ComplexTup
    TupRsingle (SingleScalarType (NumSingleType nT)) ->
      case nT of
        IntegralNumType iT ->
          case iT of
            TypeInt    -> ComplexVec singleType
            TypeInt8   -> ComplexVec singleType
            TypeInt16  -> ComplexVec singleType
            TypeInt32  -> ComplexVec singleType
            TypeInt64  -> ComplexVec singleType
            TypeWord   -> ComplexVec singleType
            TypeWord8  -> ComplexVec singleType
            TypeWord16 -> ComplexVec singleType
            TypeWord32 -> ComplexVec singleType
            TypeWord64 -> ComplexVec singleType
        FloatingNumType fT ->
          case fT of
            TypeHalf   -> ComplexVec singleType
            TypeFloat  -> ComplexVec singleType
            TypeDouble -> ComplexVec singleType



-- | Build an embedded complex number from its real and imaginary components.
--
-- When the components are packed ('ComplexVec') the two parts are coerced to
-- their representation type and combined with 'V2'; otherwise ('ComplexTup')
-- they are paired with 'T2' and the pair is coerced to the complex type.
--
constructComplex :: forall a. Elt a => Exp a -> Exp a -> Exp (Complex a)
constructComplex re im =
  case complexR (eltR @a) of
    ComplexVec _ -> V2 (coerce @a @(EltR a) re) (coerce @a @(EltR a) im)
    ComplexTup   -> coerce (T2 re im)

-- | Split an embedded complex number into its (real, imaginary) components.
--
-- For the packed representation ('ComplexVec') the underlying term is wrapped
-- in a 'VecUnpack' node for a two-element vector and the result is matched as
-- a pair; for the tuple representation ('ComplexTup') the expression is
-- coerced to a pair directly.
--
deconstructComplex :: forall a. Elt a => Exp (Complex a) -> (Exp a, Exp a)
deconstructComplex z@(Exp z') =
  case complexR (eltR @a) of
    ComplexVec t ->
      let T2 re im = Exp (SmartExp (VecUnpack (VecRsucc (VecRsucc (VecRnil t))) z'))
       in (re, im)
    ComplexTup ->
      let T2 re im = coerce z
       in (re, im)

-- | Reinterpret an expression at another surface type that has the same
-- representation type. The wrapped term is passed through unchanged.
--
coerce :: EltR a ~ EltR b => Exp a -> Exp b
coerce (Exp term) = Exp term

-- | Lift a Haskell-level complex number whose components are themselves
-- liftable: each component is lifted and the results are combined into an
-- embedded complex number.
--
instance (Lift Exp a, Elt (Plain a)) => Lift Exp (Complex a) where
  type Plain (Complex a) = Complex (Plain a)
  lift (re :+ im) = constructComplex (lift re) (lift im)

-- | Unpack an embedded complex number into a Haskell-level 'Complex' whose
-- components are embedded expressions.
--
instance Elt a => Unlift Exp (Complex (Exp a)) where
  unlift z =
    case z of
      re ::+ im -> re :+ im


-- | Component-wise equality: two complex numbers are equal when both parts
-- are equal, and unequal when either part differs.
--
instance Eq a => Eq (Complex a) where
  (==) (re1 ::+ im1) (re2 ::+ im2) = re1 == re2 && im1 == im2
  (/=) (re1 ::+ im1) (re2 ::+ im2) = re1 /= re2 || im1 /= im2

-- | Arithmetic on embedded complex numbers.
--
-- Addition, subtraction, multiplication and negation reuse the operations on
-- @'Complex' ('Exp' a)@ by unlifting the arguments and lifting the result.
--
instance RealFloat a => P.Num (Exp (Complex a)) where
  (+)    = lift2 ((+) @(Complex (Exp a)))
  (-)    = lift2 ((-) @(Complex (Exp a)))
  (*)    = lift2 ((*) @(Complex (Exp a)))
  negate = lift1 (negate @(Complex (Exp a)))

  -- Zero is returned unchanged; anything else is scaled to unit magnitude.
  signum z@(re ::+ im) =
    if z == 0
       then z
       else re / mag ::+ im / mag
    where
      mag = magnitude z

  -- The magnitude, as a complex number with zero imaginary part.
  abs z         = magnitude z ::+ 0
  fromInteger k = fromInteger k ::+ 0

-- | Division of embedded complex numbers.
--
-- The components of the divisor are rescaled by the negated larger of their
-- two exponents before the products are formed.
--
instance RealFloat a => P.Fractional (Exp (Complex a)) where
  fromRational q = fromRational q ::+ 0
  z / w          = (zr * wr' + zi * wi') / denom ::+ (zi * wr' - zr * wi') / denom
    where
      zr ::+ zi = z
      wr ::+ wi = w
      -- rescaled divisor components
      k     = - max (exponent wr) (exponent wi)
      wr'   = scaleFloat k wr
      wi'   = scaleFloat k wi
      denom = wr * wr' + wi * wi'

-- | Elementary functions on embedded complex numbers, expressed in terms of
-- the real and imaginary components.
--
instance RealFloat a => P.Floating (Exp (Complex a)) where
  pi                = pi ::+ 0

  -- exp (x + iy) = e^x * (cos y + i sin y)
  exp (x ::+ y)     = let expx = exp x
                       in expx * cos y ::+ expx * sin y

  -- log z = log |z| + i * phase z
  log z             = log (magnitude z) ::+ phase z

  -- Zero maps to zero. Otherwise (u', v') is the root for non-negative real
  -- part; the two are swapped when the real part is negative, and the sign of
  -- the imaginary part of the result follows the sign of 'y'.
  sqrt z@(x ::+ y)  =
    if z == 0
      then 0
      else u ::+ (y < 0 ? (-v, v))
    where
      T2 u v = x < 0 ? (T2 v' u', T2 u' v')
      v'     = abs y / (u'*2)
      u'     = sqrt ((magnitude z + abs x) / 2)

  -- Special cases, in order:
  --   * a zero exponent gives 1;
  --   * a zero base gives 0, infinity or NaN depending on the sign of the real
  --     part of the exponent;
  --   * a base with an infinite component gives infinity, 0 or NaN depending
  --     on the sign of the real part of the exponent;
  --   * otherwise the result is exp (log x * y).
  x ** y =
    if y == 0 then 1 else
    if x == 0 then if exp_r > 0 then 0 else
                   if exp_r < 0 then inf ::+ 0
                                else nan ::+ nan
              else if isInfinite r || isInfinite i
                     then if exp_r > 0 then inf ::+ 0 else
                          if exp_r < 0 then 0
                                       else nan ::+ nan
                     else exp (log x * y)
    where
      r     ::+ i  = x
      exp_r ::+ _  = y
      --
      inf = 1 / 0
      nan = 0 / 0

  sin (x ::+ y)  = sin x * cosh y ::+ cos x * sinh y
  cos (x ::+ y)  = cos x * cosh y ::+ (- sin x * sinh y)
  tan (x ::+ y)  = (sinx*coshy ::+ cosx*sinhy) / (cosx*coshy ::+ (-sinx*sinhy))
    where
      sinx  = sin x
      cosx  = cos x
      sinhy = sinh y
      coshy = cosh y

  sinh (x ::+ y) = cos y * sinh x ::+ sin  y * cosh x
  cosh (x ::+ y) = cos y * cosh x ::+ sin y * sinh x
  tanh (x ::+ y) = (cosy*sinhx ::+ siny*coshx) / (cosy*coshx ::+ siny*sinhx)
    where
      siny  = sin y
      cosy  = cos y
      sinhx = sinh x
      coshx = cosh x

  asin z@(x ::+ y) = y' ::+ (-x')
    where
      x' ::+ y' = log (((-y) ::+ x) + sqrt (1 - z*z))

  acos z                    = y'' ::+ (-x'')
    where
      x'' ::+ y''  = log (z + ((-y') ::+ x'))
      x'  ::+ y'   = sqrt (1 - z*z)

  atan z@(x ::+ y) = y' ::+ (-x')
    where
      x' ::+ y' = log (((1-y) ::+ x) / sqrt (1+z*z))

  asinh z =  log (z + sqrt (1+z*z))

  -- Written as a product of two square roots rather than as
  -- (z+1) * sqrt ((z-1)/(z+1)): the latter divides by zero at z = -1 and
  -- produces NaN, whereas this form yields the expected 0 + i*pi there.
  -- This is the formulation used by base's "Data.Complex".
  acosh z =  log (z + sqrt (z+1) * sqrt (z-1))

  atanh z =  0.5 * log ((1.0+z) / (1.0-z))


-- | Convert an integral value to a complex number: the converted value
-- becomes the real part and the imaginary part is zero.
--
instance (FromIntegral a b, Num b, Elt (Complex b)) => FromIntegral a (Complex b) where
  fromIntegral n = constructComplex (fromIntegral n) 0

-- | Apply a function to both the real and the imaginary component.
--
-- @since 1.2.0.0
--
instance Functor Complex where
  fmap f z =
    case z of
      re ::+ im -> f re ::+ f im


-- | The non-negative magnitude of a complex number.
--
-- Both components are scaled down by the larger of their two exponents before
-- being squared and summed, and the square root is scaled back up afterwards.
--
magnitude :: RealFloat a => Exp (Complex a) -> Exp a
magnitude (re ::+ im) = scaleFloat e (sqrt (re' * re' + im' * im'))
  where
    e    = max (exponent re) (exponent im)
    down = -e
    re'  = scaleFloat down re
    im'  = scaleFloat down im

-- | The magnitude of a complex number computed with the traditional formula,
-- the square root of the sum of the squared components. Unlike 'magnitude'
-- this performs no rescaling, so it is simpler to evaluate but ignores
-- floating point rounding.
--
-- @since 1.3.0.0
--
magnitude' :: RealFloat a => Exp (Complex a) -> Exp a
magnitude' (re ::+ im) = sqrt (sqr re + sqr im)
  where
    sqr v = v * v

-- | The phase (argument) of a complex number, in the range @(-'pi', 'pi']@.
-- A complex number equal to zero has phase zero.
--
phase :: RealFloat a => Exp (Complex a) -> Exp a
phase z@(re ::+ im) = if z == 0 then 0 else atan2 im re

-- | Convert a complex number to its canonical polar form, returned as a
-- (magnitude, phase) pair: the magnitude is non-negative and the phase lies
-- in the range @(-'pi', 'pi']@; if the magnitude is zero, then so is the
-- phase.
--
polar :: RealFloat a => Exp (Complex a) -> Exp (a,a)
polar z = T2 mag arg
  where
    mag = magnitude z
    arg = phase z

-- | Construct a complex number from its polar components: the first argument
-- is the magnitude and the second the phase.
--
mkPolar :: forall a. Floating a => Exp a -> Exp a -> Exp (Complex a)
mkPolar = lift2 (C.mkPolar @(Exp a))

-- | @'cis' t@ is the complex value with magnitude @1@ and phase @t@ (modulo
-- @2*'pi'@).
--
cis :: forall a. Floating a => Exp a -> Exp (Complex a)
cis = lift1 (C.cis @(Exp a))

-- | Extract the real component of a complex number.
--
real :: Elt a => Exp (Complex a) -> Exp a
real z = P.fst (deconstructComplex z)

-- | Extract the imaginary component of a complex number.
--
imag :: Elt a => Exp (Complex a) -> Exp a
imag z = P.snd (deconstructComplex z)

-- | The complex conjugate: the real part is kept and the imaginary part is
-- negated.
--
-- > conjugate(Z) = X - iY
--
conjugate :: Num a => Exp (Complex a) -> Exp (Complex a)
conjugate z = re ::+ negate im
  where
    re = real z
    im = imag z