{-# language BangPatterns #-}
{-# language MagicHash #-}
{-# language RankNTypes #-}
{-# language ScopedTypeVariables #-}
{-# language TypeFamilies #-}
{-# language UnboxedTuples #-}
module Data.Primitive.Unlifted.Array
(
UnliftedArray(..)
, MutableUnliftedArray(..)
, newUnliftedArray
, unsafeNewUnliftedArray
, sizeofUnliftedArray
, sizeofMutableUnliftedArray
, sameMutableUnliftedArray
, writeUnliftedArray
, readUnliftedArray
, indexUnliftedArray
, unsafeFreezeUnliftedArray
, freezeUnliftedArray
, thawUnliftedArray
, setUnliftedArray
, copyUnliftedArray
, copyMutableUnliftedArray
, cloneUnliftedArray
, cloneMutableUnliftedArray
, emptyUnliftedArray
, runUnliftedArray
, unliftedArrayToList
, unliftedArrayFromList
, unliftedArrayFromListN
, foldrUnliftedArray
, foldrUnliftedArray'
, foldlUnliftedArray
, foldlUnliftedArray'
, foldlUnliftedArrayM'
, traverseUnliftedArray_
, itraverseUnliftedArray_
, mapUnliftedArray
) where
import Control.Monad.Primitive (PrimMonad,PrimState,primitive,primitive_)
import Control.Monad.ST (ST)
import Data.Primitive.Unlifted.Class (PrimUnlifted)
import GHC.Exts (Int(I#),MutableArrayArray#,ArrayArray#,State#)
import qualified Data.List as L
import qualified Data.Primitive.Unlifted.Class as C
import qualified GHC.Exts as Exts
import qualified GHC.ST as ST
-- | A mutable array backed by a 'MutableArrayArray#' that is tied to the
-- state token @s@. The element type @a@ does not occur in the
-- representation; it is tracked at the type level only.
--
-- NOTE(review): @a@ is phantom here and the file has no role annotation,
-- so GHC will presumably infer a phantom role for it — confirm whether a
-- @type role@ declaration is wanted.
data MutableUnliftedArray s a = MutableUnliftedArray (MutableArrayArray# s)
-- | An immutable array backed by an 'ArrayArray#'. The element type @a@
-- does not occur in the representation; it is tracked at the type level
-- only.
--
-- NOTE(review): @a@ is phantom here and the file has no role annotation,
-- so GHC will presumably infer a phantom role for it — confirm whether a
-- @type role@ declaration is wanted.
data UnliftedArray a = UnliftedArray ArrayArray#
-- | Allocate a mutable array with room for the given number of elements
-- by calling @newArrayArray#@. This function writes no elements itself;
-- within this module, 'newUnliftedArray' is the variant that fills every
-- slot before handing the array back.
unsafeNewUnliftedArray
  :: (PrimMonad m)
  => Int -- ^ number of elements
  -> m (MutableUnliftedArray (PrimState m) a)
unsafeNewUnliftedArray (I# len#) = primitive $ \s0 ->
  case Exts.newArrayArray# len# s0 of
    (# s1, marr# #) -> (# s1, MutableUnliftedArray marr# #)
{-# INLINE unsafeNewUnliftedArray #-}
-- | Allocate a mutable array of the given length and store the supplied
-- value in every slot (indices @0@ up to @len - 1@) before returning it.
newUnliftedArray
  :: (PrimMonad m, PrimUnlifted a)
  => Int -- ^ number of elements
  -> a -- ^ value written to every slot
  -> m (MutableUnliftedArray (PrimState m) a)
newUnliftedArray len v =
  unsafeNewUnliftedArray len >>= \fresh ->
    setUnliftedArray fresh v 0 len >> pure fresh
{-# INLINE newUnliftedArray #-}
-- | Overwrite a run of slots with a single value. The run starts at the
-- given offset and spans the given number of slots; writes are issued
-- from the highest index of the run down to the offset. No bounds check
-- is performed here.
setUnliftedArray
  :: (PrimMonad m, PrimUnlifted a)
  => MutableUnliftedArray (PrimState m) a -- ^ array to modify
  -> a -- ^ value to store
  -> Int -- ^ offset of the first slot
  -> Int -- ^ number of slots
  -> m ()
setUnliftedArray mua v off len = go (off + len - 1)
  where
    -- Walk downwards from the last slot of the run and stop once the
    -- index drops below the starting offset.
    go ix =
      if ix >= off
        then writeUnliftedArray mua ix v *> go (ix - 1)
        else pure ()
{-# INLINE setUnliftedArray #-}
-- | Number of elements in an immutable array, as reported by
-- @sizeofArrayArray#@.
sizeofUnliftedArray :: UnliftedArray e -> Int
sizeofUnliftedArray arr = case arr of
  UnliftedArray arr# -> I# (Exts.sizeofArrayArray# arr#)
{-# INLINE sizeofUnliftedArray #-}
-- | Number of elements in a mutable array, as reported by
-- @sizeofMutableArrayArray#@.
sizeofMutableUnliftedArray :: MutableUnliftedArray s e -> Int
sizeofMutableUnliftedArray marr = case marr of
  MutableUnliftedArray marr# -> I# (Exts.sizeofMutableArrayArray# marr#)
{-# INLINE sizeofMutableUnliftedArray #-}
-- | Store a value at the given index of a mutable array by delegating to
-- the 'PrimUnlifted' method @writeUnliftedArray#@. No bounds check is
-- performed here.
writeUnliftedArray :: (PrimMonad m, PrimUnlifted a)
  => MutableUnliftedArray (PrimState m) a -- ^ destination array
  -> Int -- ^ index
  -> a -- ^ value to store
  -> m ()
writeUnliftedArray (MutableUnliftedArray marr#) (I# ix#) x =
  primitive_ $ \s -> C.writeUnliftedArray# marr# ix# x s
{-# INLINE writeUnliftedArray #-}
-- | Fetch the value at the given index of a mutable array by delegating
-- to the 'PrimUnlifted' method @readUnliftedArray#@. No bounds check is
-- performed here.
readUnliftedArray :: (PrimMonad m, PrimUnlifted a)
  => MutableUnliftedArray (PrimState m) a -- ^ source array
  -> Int -- ^ index
  -> m a
readUnliftedArray (MutableUnliftedArray marr#) (I# ix#) =
  primitive $ \s -> C.readUnliftedArray# marr# ix# s
{-# INLINE readUnliftedArray #-}
-- | Fetch the value at the given index of an immutable array by
-- delegating to the 'PrimUnlifted' method @indexUnliftedArray#@. No
-- bounds check is performed here.
indexUnliftedArray :: PrimUnlifted a
  => UnliftedArray a -- ^ source array
  -> Int -- ^ index
  -> a
indexUnliftedArray (UnliftedArray arr#) ix = case ix of
  I# ix# -> C.indexUnliftedArray# arr# ix#
{-# INLINE indexUnliftedArray #-}
-- | Turn a mutable array into an immutable one by calling
-- @unsafeFreezeArrayArray#@ on its underlying storage. This function
-- allocates no new array and copies no elements itself.
unsafeFreezeUnliftedArray
  :: (PrimMonad m)
  => MutableUnliftedArray (PrimState m) a
  -> m (UnliftedArray a)
unsafeFreezeUnliftedArray marr = case marr of
  MutableUnliftedArray marr# -> primitive $ \s0 ->
    case Exts.unsafeFreezeArrayArray# marr# s0 of
      (# s1, arr# #) -> (# s1, UnliftedArray arr# #)
{-# INLINE unsafeFreezeUnliftedArray #-}
-- | Test whether two mutable arrays are the same array, as decided by
-- @sameMutableArrayArray#@. Element contents are not inspected.
sameMutableUnliftedArray
  :: MutableUnliftedArray s a
  -> MutableUnliftedArray s a
  -> Bool
sameMutableUnliftedArray lhs rhs =
  case lhs of
    MutableUnliftedArray lhs# -> case rhs of
      MutableUnliftedArray rhs# ->
        Exts.isTrue# (Exts.sameMutableArrayArray# lhs# rhs#)
{-# INLINE sameMutableUnliftedArray #-}
-- | Copy a slice of an immutable array into a mutable array via
-- @copyArrayArray#@. No bounds check is performed here.
copyUnliftedArray
  :: (PrimMonad m)
  => MutableUnliftedArray (PrimState m) a -- ^ destination
  -> Int -- ^ offset into destination
  -> UnliftedArray a -- ^ source
  -> Int -- ^ offset into source
  -> Int -- ^ length of the slice
  -> m ()
copyUnliftedArray
  (MutableUnliftedArray dst#) (I# dstOff#)
  (UnliftedArray src#) (I# srcOff#) (I# count#) =
    -- The primop takes the source first, so the arguments are reordered.
    primitive_ (\s -> Exts.copyArrayArray# src# srcOff# dst# dstOff# count# s)
{-# INLINE copyUnliftedArray #-}
-- | Copy a slice of one mutable array into another mutable array via
-- @copyMutableArrayArray#@. No bounds check is performed here.
copyMutableUnliftedArray
  :: (PrimMonad m)
  => MutableUnliftedArray (PrimState m) a -- ^ destination
  -> Int -- ^ offset into destination
  -> MutableUnliftedArray (PrimState m) a -- ^ source
  -> Int -- ^ offset into source
  -> Int -- ^ length of the slice
  -> m ()
copyMutableUnliftedArray
  (MutableUnliftedArray dst#) (I# dstOff#)
  (MutableUnliftedArray src#) (I# srcOff#) (I# count#) =
    -- The primop takes the source first, so the arguments are reordered.
    primitive_ (\s -> Exts.copyMutableArrayArray# src# srcOff# dst# dstOff# count# s)
{-# INLINE copyMutableUnliftedArray #-}
-- | Produce an immutable array from a slice of a mutable array. The slice
-- is first copied into a newly allocated array, and that new array is
-- the one that gets frozen.
freezeUnliftedArray
  :: (PrimMonad m)
  => MutableUnliftedArray (PrimState m) a -- ^ source
  -> Int -- ^ offset of the slice in the source
  -> Int -- ^ length of the slice
  -> m (UnliftedArray a)
freezeUnliftedArray src off len =
  unsafeNewUnliftedArray len >>= \fresh ->
    copyMutableUnliftedArray fresh 0 src off len
      >> unsafeFreezeUnliftedArray fresh
{-# INLINE freezeUnliftedArray #-}
-- | Produce a mutable array from a slice of an immutable array. The
-- slice is copied into a newly allocated array, which is returned.
thawUnliftedArray
  :: (PrimMonad m)
  => UnliftedArray a -- ^ source
  -> Int -- ^ offset of the slice in the source
  -> Int -- ^ length of the slice
  -> m (MutableUnliftedArray (PrimState m) a)
thawUnliftedArray src off len =
  unsafeNewUnliftedArray len >>= \fresh ->
    copyUnliftedArray fresh 0 src off len >> return fresh
{-# INLINE thawUnliftedArray #-}
-- | Build an immutable array by allocating a mutable array of the given
-- length with 'unsafeNewUnliftedArray', running the supplied 'ST' action
-- on it, and freezing it through 'runUnliftedArray'. The length is forced
-- before anything else happens.
unsafeCreateUnliftedArray
  :: Int -- ^ number of elements
  -> (forall s. MutableUnliftedArray s a -> ST s ()) -- ^ initialiser
  -> UnliftedArray a
unsafeCreateUnliftedArray !n f =
  runUnliftedArray
    (unsafeNewUnliftedArray n >>= \marr -> f marr >> pure marr)
-- | Run an 'ST' action that yields a mutable array and return that array
-- as an immutable one. The work is done by 'runUnliftedArray#'; this
-- wrapper only boxes the resulting 'ArrayArray#'.
runUnliftedArray
  :: (forall s. ST s (MutableUnliftedArray s a))
  -> UnliftedArray a
{-# INLINE runUnliftedArray #-}
runUnliftedArray m = case runUnliftedArray# m of
  frozen# -> UnliftedArray frozen#
-- | Run an 'ST' action under @runRW#@, freeze the mutable array it
-- returns with @unsafeFreezeArrayArray#@, and hand back the raw
-- 'ArrayArray#'. The final state token is discarded.
runUnliftedArray#
  :: (forall s. ST s (MutableUnliftedArray s a))
  -> ArrayArray#
runUnliftedArray# m =
  case Exts.runRW# (\s0 ->
         case unST m s0 of
           (# s1, MutableUnliftedArray marr# #) ->
             Exts.unsafeFreezeArrayArray# marr# s1) of
    (# _, frozen# #) -> frozen#
-- | Unwrap an 'ST' action into the state-passing function held by its
-- 'ST.ST' constructor.
unST :: ST s a -> State# s -> (# State# s, a #)
unST action = case action of
  ST.ST run -> run
-- | Copy a slice of an immutable array into a new immutable array. The
-- slice is copied into a newly allocated mutable array, which is then
-- frozen by 'runUnliftedArray'.
cloneUnliftedArray
  :: UnliftedArray a -- ^ source
  -> Int -- ^ offset of the slice in the source
  -> Int -- ^ length of the slice
  -> UnliftedArray a
cloneUnliftedArray src off len =
  runUnliftedArray
    ( unsafeNewUnliftedArray len >>= \fresh ->
        copyUnliftedArray fresh 0 src off len >> return fresh
    )
{-# INLINE cloneUnliftedArray #-}
-- | Copy a slice of a mutable array into a newly allocated mutable
-- array and return the new array.
cloneMutableUnliftedArray
  :: (PrimMonad m)
  => MutableUnliftedArray (PrimState m) a -- ^ source
  -> Int -- ^ offset of the slice in the source
  -> Int -- ^ length of the slice
  -> m (MutableUnliftedArray (PrimState m) a)
cloneMutableUnliftedArray src off len =
  unsafeNewUnliftedArray len >>= \fresh ->
    copyMutableUnliftedArray fresh 0 src off len >> return fresh
{-# INLINE cloneMutableUnliftedArray #-}
-- | An immutable array with zero elements.
--
-- NOTE(review): marked @NOINLINE@, presumably so that one shared value is
-- reused rather than re-created at each use site — confirm.
emptyUnliftedArray :: UnliftedArray a
emptyUnliftedArray = unsafeCreateUnliftedArray 0 (\_ -> pure ())
{-# NOINLINE emptyUnliftedArray #-}
-- | Concatenate two immutable arrays into a new one. The result has the
-- combined length; the first array occupies the leading slots and the
-- second array follows immediately after it.
concatUnliftedArray :: UnliftedArray a -> UnliftedArray a -> UnliftedArray a
{-# INLINE concatUnliftedArray #-}
concatUnliftedArray x y =
  unsafeCreateUnliftedArray (lenX + lenY) (\dst ->
    copyUnliftedArray dst 0 x 0 lenX
      >> copyUnliftedArray dst lenX y 0 lenY)
  where
    lenX = sizeofUnliftedArray x
    lenY = sizeofUnliftedArray y
-- | Lazy right fold over an immutable array, visiting indices from @0@
-- upwards. The combining function receives the element and the (not yet
-- evaluated) fold of the remaining elements.
foldrUnliftedArray :: forall a b. PrimUnlifted a => (a -> b -> b) -> b -> UnliftedArray a -> b
{-# INLINE foldrUnliftedArray #-}
foldrUnliftedArray f z arr = walk 0
  where
    !len = sizeofUnliftedArray arr
    walk !ix =
      if ix < len
        then f (indexUnliftedArray arr ix) (walk (ix + 1))
        else z
-- | Strict right fold over an immutable array. Iteration starts at the
-- last index and moves towards @0@; the accumulator is forced at every
-- step.
foldrUnliftedArray' :: forall a b. PrimUnlifted a => (a -> b -> b) -> b -> UnliftedArray a -> b
{-# INLINE foldrUnliftedArray' #-}
foldrUnliftedArray' f z0 arr = walk lastIx z0
  where
    lastIx = sizeofUnliftedArray arr - 1
    walk !ix !acc =
      if ix >= 0
        then walk (ix - 1) (f (indexUnliftedArray arr ix) acc)
        else acc
-- | Lazy left fold over an immutable array. The recursion starts at the
-- last index; the combining function receives the (not yet evaluated)
-- fold of all earlier elements and the current element.
foldlUnliftedArray :: forall a b. PrimUnlifted a => (b -> a -> b) -> b -> UnliftedArray a -> b
{-# INLINE foldlUnliftedArray #-}
foldlUnliftedArray f z arr = walk lastIx
  where
    lastIx = sizeofUnliftedArray arr - 1
    walk !ix =
      if ix >= 0
        then f (walk (ix - 1)) (indexUnliftedArray arr ix)
        else z
-- | Strict left fold over an immutable array, visiting indices from @0@
-- upwards. The accumulator is forced at every step.
foldlUnliftedArray' :: forall a b. PrimUnlifted a => (b -> a -> b) -> b -> UnliftedArray a -> b
{-# INLINE foldlUnliftedArray' #-}
foldlUnliftedArray' f z0 arr = walk 0 z0
  where
    !len = sizeofUnliftedArray arr
    walk !ix !acc
      | ix >= len = acc
      | otherwise = walk (ix + 1) (f acc (indexUnliftedArray arr ix))
-- | Monadic left fold over an immutable array, visiting indices from @0@
-- upwards. The accumulator is forced before each step is run.
foldlUnliftedArrayM' :: (PrimUnlifted a, Monad m)
  => (b -> a -> m b) -> b -> UnliftedArray a -> m b
{-# INLINE foldlUnliftedArrayM' #-}
foldlUnliftedArrayM' f z0 arr = walk 0 z0
  where
    !len = sizeofUnliftedArray arr
    walk !ix !acc
      | ix >= len = pure acc
      | otherwise = f acc (indexUnliftedArray arr ix) >>= walk (ix + 1)
-- | Run an effect for every element of an immutable array, visiting
-- indices from @0@ upwards and discarding each result.
traverseUnliftedArray_ :: (PrimUnlifted a, Applicative m)
  => (a -> m b) -> UnliftedArray a -> m ()
{-# INLINE traverseUnliftedArray_ #-}
traverseUnliftedArray_ f arr = walk 0
  where
    !len = sizeofUnliftedArray arr
    walk !ix =
      if ix < len
        then f (indexUnliftedArray arr ix) *> walk (ix + 1)
        else pure ()
-- | Run an effect for every element of an immutable array, passing the
-- element's index along with it. Indices are visited from @0@ upwards
-- and each result is discarded.
itraverseUnliftedArray_ :: (PrimUnlifted a, Applicative m)
  => (Int -> a -> m b) -> UnliftedArray a -> m ()
{-# INLINE itraverseUnliftedArray_ #-}
itraverseUnliftedArray_ f arr = walk 0
  where
    !len = sizeofUnliftedArray arr
    walk !ix =
      if ix < len
        then f ix (indexUnliftedArray arr ix) *> walk (ix + 1)
        else pure ()
-- | Apply a function to every element of an immutable array, producing a
-- new array of the same length. Elements are written to the result in
-- index order, starting at @0@.
mapUnliftedArray :: (PrimUnlifted a, PrimUnlifted b)
  => (a -> b)
  -> UnliftedArray a
  -> UnliftedArray b
{-# INLINE mapUnliftedArray #-}
mapUnliftedArray f arr = unsafeCreateUnliftedArray len (\dst ->
  let fill !ix
        | ix >= len = return ()
        | otherwise =
            writeUnliftedArray dst ix (f (indexUnliftedArray arr ix))
              >> fill (ix + 1)
   in fill 0)
  where
    !len = sizeofUnliftedArray arr
-- | Convert an immutable array to a list, in index order. The list is
-- produced through 'Exts.build' on top of 'foldrUnliftedArray'.
unliftedArrayToList :: PrimUnlifted a => UnliftedArray a -> [a]
{-# INLINE unliftedArrayToList #-}
unliftedArrayToList xs =
  Exts.build (\cons nil -> foldrUnliftedArray cons nil xs)
-- | Build an immutable array from a list. The list's length is computed
-- first and then passed, together with the list, to
-- 'unliftedArrayFromListN'.
unliftedArrayFromList :: PrimUnlifted a => [a] -> UnliftedArray a
unliftedArrayFromList xs =
  let count = L.length xs
   in unliftedArrayFromListN count xs
-- | Build an immutable array of the stated length from a list.
--
-- Elements are written in list order starting at index @0@. If the list
-- runs out before the stated length is reached, or still has elements
-- once it has been reached, the function fails through 'die'.
unliftedArrayFromListN :: forall a. PrimUnlifted a => Int -> [a] -> UnliftedArray a
unliftedArrayFromListN len vs = unsafeCreateUnliftedArray len fill
  where
    fill :: forall s. MutableUnliftedArray s a -> ST s ()
    fill marr = go vs 0
      where
        -- The second argument is the index of the next slot to write.
        go :: [a] -> Int -> ST s ()
        go [] !ix
          | ix == len = return ()
          | otherwise = die "unliftedArrayFromListN" "list length less than specified size"
        go (x : rest) !ix
          | ix < len = writeUnliftedArray marr ix x >> go rest (ix + 1)
          | otherwise = die "unliftedArrayFromListN" "list length greater than specified size"
-- | List conversions for 'UnliftedArray', implemented with
-- 'unliftedArrayToList', 'unliftedArrayFromList' and
-- 'unliftedArrayFromListN'.
instance PrimUnlifted a => Exts.IsList (UnliftedArray a) where
  type Item (UnliftedArray a) = a
  toList arr = unliftedArrayToList arr
  fromList xs = unliftedArrayFromList xs
  fromListN n xs = unliftedArrayFromListN n xs
-- | Arrays combine by concatenation (see 'concatUnliftedArray').
instance PrimUnlifted a => Semigroup (UnliftedArray a) where
  lhs <> rhs = concatUnliftedArray lhs rhs
-- | The identity element is the zero-length array 'emptyUnliftedArray'.
instance PrimUnlifted a => Monoid (UnliftedArray a) where
  mempty = emptyUnliftedArray
-- | Renders an array as @fromListN \<size\> \<list of elements\>@, wrapped
-- in parentheses when the surrounding precedence is greater than 10.
instance (Show a, PrimUnlifted a) => Show (UnliftedArray a) where
  showsPrec prec arr = showParen (prec > 10) rendered
    where
      rendered =
        showString "fromListN "
          . shows (sizeofUnliftedArray arr)
          . showString " "
          . shows (unliftedArrayToList arr)
-- | Two mutable arrays are equal exactly when 'sameMutableUnliftedArray'
-- reports them to be the same array; element contents are not compared.
instance Eq (MutableUnliftedArray s a) where
  lhs == rhs = sameMutableUnliftedArray lhs rhs
-- | Two immutable arrays are equal when they have the same length and
-- equal elements at every index. Elements are compared starting at the
-- last index and working down to @0@, stopping at the first mismatch.
instance (Eq a, PrimUnlifted a) => Eq (UnliftedArray a) where
  lhs == rhs
    | len /= sizeofUnliftedArray rhs = False
    | otherwise = elementsMatch (len - 1)
    where
      len = sizeofUnliftedArray lhs
      elementsMatch ix =
        ix < 0
          || ( indexUnliftedArray lhs ix == indexUnliftedArray rhs ix
                 && elementsMatch (ix - 1)
             )
-- | Abort with an 'error' whose message names this module and the
-- function that detected the problem.
--
-- The message has the shape
-- @Data.Primitive.Unlifted.Array.\<fun\>: \<problem\>@.
die
  :: String -- ^ name of the function reporting the problem
  -> String -- ^ description of the problem
  -> a
die fun problem =
  -- The prefix matches the module name declared in this file's header.
  error $ "Data.Primitive.Unlifted.Array." ++ fun ++ ": " ++ problem