{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE KindSignatures      #-}
{-# LANGUAGE MagicHash           #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell     #-}
{-# LANGUAGE TypeOperators       #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module      : Data.Array.Accelerate.Representation.Vec
-- Copyright   : [2008..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.Representation.Vec
  where

import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Representation.Type
import Data.Primitive.Vec

import Control.Monad.ST
import Data.Primitive.ByteArray
import Data.Primitive.Types
import Language.Haskell.TH

import GHC.Base                                         ( Int(..), Int#, (-#) )
import GHC.TypeNats


-- | Declares the size of a SIMD vector and the type of its elements. This
-- data type is used to denote the relation between a vector type (@Vec n
-- single@) and its tuple representation (@tuple@). Conversions between
-- those types are exposed through 'pack' and 'unpack'.
--
data VecR (n :: Nat) single tuple where
  -- | A vector of zero elements of the given element type, whose tuple
  -- representation is @()@.
  VecRnil  :: SingleType s -> VecR 0       s ()
  -- | Extend a vector by one element: the existing representation becomes
  -- the left component of a pair and the new element the right component.
  VecRsucc :: VecR n s t   -> VecR (n + 1) s (t, s)

-- | Recover the 'VectorType' of the SIMD vector described by a 'VecR'.
--
-- The vector width is obtained by counting the 'VecRsucc' layers, and the
-- element type is the 'SingleType' stored in the innermost 'VecRnil'.
--
vecRvector :: KnownNat n => VecR n s tuple -> VectorType (Vec n s)
vecRvector = uncurry VectorType . go
  where
    -- Walk down to the base of the representation, returning the number of
    -- components seen together with the element type found there.
    go :: VecR n s tuple -> (Int, SingleType s)
    go (VecRnil tp)                       = (0,     tp)
    go (VecRsucc vec) | (n, tp) <- go vec = (n + 1, tp)

-- | Compute the tuple representation ('TypeR') corresponding to a 'VecR'.
--
-- The base case is 'TupRunit'. Each 'VecRsucc' pairs the representation
-- built so far (on the left) with one more scalar component of the element
-- type (on the right).
--
vecRtuple :: VecR n s tuple -> TypeR tuple
vecRtuple = snd . go
  where
    -- The element type is threaded back up from 'VecRnil' so that it can be
    -- used to build every scalar component.
    go :: VecR n s tuple -> (SingleType s, TypeR tuple)
    go (VecRnil tp)                           = (tp, TupRunit)
    go (VecRsucc vec) | (tp, tuple) <- go vec = (tp, TupRpair tuple (TupRsingle (SingleScalarType tp)))

-- | Convert the nested-tuple representation of a SIMD vector into a 'Vec'.
--
-- The right-most component of the tuple is written at index @n - 1@ and the
-- left-most at index @0@; 'unpack' reads them back in the same layout.
--
pack :: forall n single tuple. KnownNat n => VecR n single tuple -> tuple -> Vec n single
pack vecR tuple
  | VectorType n single <- vecRvector vecR
  , SingleDict          <- singleDict single
  = runST $ do
      mba <- newByteArray (n * sizeOf (undefined :: single))
      go (n - 1) vecR tuple mba
      -- 'mba' is not used after this point, so freezing it in place
      -- (without copying) is safe.
      ByteArray ba# <- unsafeFreezeByteArray mba
      return $! Vec ba#
  where
    -- Write the components from the highest index downwards, peeling one
    -- element off the right of the nested tuple at each step.
    go :: Prim single => Int -> VecR n' single tuple' -> tuple' -> MutableByteArray s -> ST s ()
    go _ (VecRnil _)  ()      _   = return ()
    go i (VecRsucc r) (xs, x) mba = do
      writeByteArray mba i x
      go (i - 1) r xs mba

-- | Convert a 'Vec' into its nested-tuple representation, using the same
-- layout as 'pack'.
--
-- The element at index @n - 1@ becomes the right-most tuple component and
-- the element at index @0@ the left-most.
--
unpack :: forall n single tuple. KnownNat n => VecR n single tuple -> Vec n single -> tuple
unpack vecR (Vec ba#)
  | VectorType n single <- vecRvector vecR
  , (I# n#)             <- n
  , SingleDict          <- singleDict single
  = go (n# -# 1#) vecR
  where
    -- Read the components from the highest index downwards.
    go :: Prim single => Int# -> VecR n' single tuple' -> tuple'
    go _  (VecRnil _)  = ()
    go i# (VecRsucc r) = x `seq` xs `seq` (xs, x)
      where
        -- Both the element and the remainder are forced before the pair is
        -- built; presumably so the result holds no deferred reads of 'ba#'.
        xs = go (i# -# 1#) r
        x  = indexByteArray# ba# i#

-- | Reduce a 'VecR' to normal form.
--
-- The only payload is the 'SingleType' at the base, which is forced with
-- 'rnfSingleType'; the 'VecRsucc' layers are simply traversed.
--
rnfVecR :: VecR n single tuple -> ()
rnfVecR (VecRnil tp)   = rnfSingleType tp
rnfVecR (VecRsucc vec) = rnfVecR vec

-- | Lift a 'VecR' into a typed Template Haskell expression that rebuilds
-- the same value; the base element type is lifted with 'liftSingleType'.
--
liftVecR :: VecR n single tuple -> Q (TExp (VecR n single tuple))
liftVecR (VecRnil tp)   = [|| VecRnil $$(liftSingleType tp) ||]
liftVecR (VecRsucc vec) = [|| VecRsucc $$(liftVecR vec) ||]