{-# LANGUAGE CPP #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Array.Accelerate.IO.Data.ByteString
where
import Data.Array.Accelerate.Array.Data
import Data.Array.Accelerate.Array.Unique
import Data.Array.Accelerate.Lifetime
import Data.Array.Accelerate.Sugar.Array
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.Sugar.Shape
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Type
import qualified Data.Array.Accelerate.Representation.Array as R
import qualified Data.Array.Accelerate.Representation.Shape as R
import Data.Primitive.Vec
import Data.ByteString as B
import Data.ByteString.Internal as B
import Foreign.ForeignPtr
import Foreign.Storable
import System.IO.Unsafe
#if !MIN_VERSION_base(4,10,0)
import GHC.ForeignPtr
#endif
-- | The ByteString-based counterpart of an element representation type.
--
-- The mapping follows the structure of the representation: unit maps to
-- unit, a pair maps to the pair of the mappings of its components, a SIMD
-- vector @Vec n a@ maps to whatever its component type @a@ maps to, and each
-- of the listed primitive types maps to exactly one 'ByteString'.
--
type family ByteStrings e where
  -- Structural cases
  ByteStrings ()        = ()
  ByteStrings (a,b)     = (ByteStrings a, ByteStrings b)
  ByteStrings (Vec n a) = ByteStrings a

  -- Signed integral types
  ByteStrings Int       = ByteString
  ByteStrings Int8      = ByteString
  ByteStrings Int16     = ByteString
  ByteStrings Int32     = ByteString
  ByteStrings Int64     = ByteString

  -- Unsigned integral types
  ByteStrings Word      = ByteString
  ByteStrings Word8     = ByteString
  ByteStrings Word16    = ByteString
  ByteStrings Word32    = ByteString
  ByteStrings Word64    = ByteString

  -- Floating point types
  ByteStrings Half      = ByteString
  ByteStrings Float     = ByteString
  ByteStrings Double    = ByteString

  -- Remaining primitive types
  ByteStrings Bool      = ByteString
  ByteStrings Char      = ByteString
-- | Build an array of the given shape directly on top of the memory owned by
-- the supplied 'ByteString's; the bytes are not copied.
--
-- One 'ByteString' is expected per primitive component of the element
-- representation, arranged as described by the 'ByteStrings' type family.
--
-- Each 'ByteString' must contain at least
-- @size sh * lanes * sizeOf component@ bytes, where @lanes@ is 1 for a scalar
-- component and the vector width for a SIMD vector component. A shorter
-- 'ByteString' results in a call to 'error' when the corresponding array
-- data is demanded, rather than in reads beyond the end of the buffer.
-- Additional trailing bytes are permitted and ignored.
--
{-# INLINE fromByteStrings #-}
fromByteStrings :: forall sh e. (Shape sh, Elt e) => sh -> ByteStrings (EltR e) -> Array sh e
fromByteStrings sh bs = Array (R.Array shR (tuple (eltR @e) bs))
  where
    -- Representation of the requested shape
    shR = fromElt sh

    -- Number of array elements described by the shape
    n :: Int
    n = R.size (shapeR @sh) shR

    -- Reuse the buffer of a ByteString as array data holding @n * k@ values
    -- of type @a@, after checking that the buffer is large enough.
    wrap :: forall a. Storable a => Int -> ByteString -> UniqueArray a
    wrap k (B.toForeignPtr -> (ps,s,l))
      | l < required = error ("fromByteStrings: ByteString too short; expected at least "
                                ++ show required ++ " bytes but got " ++ show l)
      | otherwise    = unsafePerformIO $ newUniqueArray (castForeignPtr (plusForeignPtr ps s))
      where
        required = n * k * sizeOf (undefined :: a)

    tuple :: TypeR a -> ByteStrings a -> ArrayData a
    tuple TupRunit           ()       = ()
    tuple (TupRpair aR1 aR2) (a1, a2) = (tuple aR1 a1, tuple aR2 a2)
    tuple (TupRsingle t)     a        = scalar t 1 a   -- scalars occupy one lane

    scalar :: ScalarType a -> Int -> ByteStrings a -> ArrayData a
    scalar (SingleScalarType t) = single t
    scalar (VectorScalarType t) = vector t

    -- A SIMD vector stores @w@ components per array element, so the lane
    -- count passed in is replaced by the vector width.
    vector :: VectorType a -> Int -> ByteStrings a -> ArrayData a
    vector (VectorType w t) _
      | SingleArrayDict <- singleArrayDict t
      = single t w

    single :: SingleType a -> Int -> ByteStrings a -> ArrayData a
    single (NumSingleType t) = num t

    num :: NumType a -> Int -> ByteStrings a -> ArrayData a
    num (IntegralNumType t) = integral t
    num (FloatingNumType t) = floating t

    integral :: IntegralType a -> Int -> ByteStrings a -> ArrayData a
    integral TypeInt    = wrap
    integral TypeInt8   = wrap
    integral TypeInt16  = wrap
    integral TypeInt32  = wrap
    integral TypeInt64  = wrap
    integral TypeWord   = wrap
    integral TypeWord8  = wrap
    integral TypeWord16 = wrap
    integral TypeWord32 = wrap
    integral TypeWord64 = wrap

    floating :: FloatingType a -> Int -> ByteStrings a -> ArrayData a
    floating TypeHalf   = wrap
    floating TypeFloat  = wrap
    floating TypeDouble = wrap
-- | View the payload of an array as 'ByteString's, one per primitive
-- component of the element representation (see 'ByteStrings').
--
-- No bytes are copied: every resulting 'ByteString' is built from the
-- foreign pointer that backs the corresponding array component. Its length
-- is the number of array elements, times the number of lanes of the
-- component (1 for scalars, the vector width for SIMD vectors), times the
-- storage size of the component type.
--
{-# INLINE toByteStrings #-}
toByteStrings :: forall sh e. (Shape sh, Elt e) => Array sh e -> ByteStrings (EltR e)
toByteStrings (Array (R.Array extent payload)) = goTuple (eltR @e) payload
  where
    -- Number of elements in the array
    elems :: Int
    elems = R.size (shapeR @sh) extent

    -- Expose a single component buffer holding @lanes@ values per element
    expose :: forall a. Storable a => UniqueArray a -> Int -> ByteString
    expose ua lanes =
      let fp    = unsafeGetValue (uniqueArrayData ua)
          bytes = elems * lanes * sizeOf (undefined :: a)
      in  B.fromForeignPtr (castForeignPtr fp) 0 bytes

    -- Walk the tuple structure of the element representation
    goTuple :: TypeR a -> ArrayData a -> ByteStrings a
    goTuple TupRunit         ()     = ()
    goTuple (TupRpair tl tr) (l, r) = (goTuple tl l, goTuple tr r)
    goTuple (TupRsingle st)  arr    = goScalar st arr 1

    goScalar :: ScalarType a -> ArrayData a -> Int -> ByteStrings a
    goScalar (SingleScalarType st) = goSingle st
    goScalar (VectorScalarType vt) = goVector vt

    -- The incoming lane count is discarded in favour of the vector width
    goVector :: VectorType a -> ArrayData a -> Int -> ByteStrings a
    goVector (VectorType width st) arr _
      | SingleArrayDict <- singleArrayDict st
      = goSingle st arr width

    goSingle :: SingleType a -> ArrayData a -> Int -> ByteStrings a
    goSingle (NumSingleType (IntegralNumType it)) = goIntegral it
    goSingle (NumSingleType (FloatingNumType ft)) = goFloating ft

    goIntegral :: IntegralType a -> ArrayData a -> Int -> ByteStrings a
    goIntegral TypeInt    = expose
    goIntegral TypeInt8   = expose
    goIntegral TypeInt16  = expose
    goIntegral TypeInt32  = expose
    goIntegral TypeInt64  = expose
    goIntegral TypeWord   = expose
    goIntegral TypeWord8  = expose
    goIntegral TypeWord16 = expose
    goIntegral TypeWord32 = expose
    goIntegral TypeWord64 = expose

    goFloating :: FloatingType a -> ArrayData a -> Int -> ByteStrings a
    goFloating TypeHalf   = expose
    goFloating TypeFloat  = expose
    goFloating TypeDouble = expose
#if !MIN_VERSION_base(4,10,0)
-- | Compatibility definition, compiled only when the CPP guard reports a
-- base older than 4.10. It rebuilds the given 'ForeignPtr' with its address
-- field advanced by the given offset (via @plusAddr#@) and its second
-- constructor field passed through unchanged. The element type of the
-- result is unrelated to that of the argument.
--
-- NOTE(review): this equation uses the MagicHash names @I#@ and @plusAddr#@,
-- yet the visible file header enables no MagicHash pragma and imports only
-- GHC.ForeignPtr under this guard. Confirm that this branch still builds on
-- such a base, or that it is provided for elsewhere.
plusForeignPtr :: ForeignPtr a -> Int -> ForeignPtr b
plusForeignPtr (ForeignPtr addr# c) (I# d#) = ForeignPtr (plusAddr# addr# d#) c
#endif