module Foundation.Primitive.Block
( Block(..)
, MutableBlock(..)
, length
, unsafeThaw
, unsafeFreeze
, unsafeIndex
, thaw
, freeze
, copy
, create
, singleton
, replicate
, index
, map
, foldl
, foldl'
, foldr
, cons
, snoc
, uncons
, unsnoc
, sub
, splitAt
, revSplitAt
, splitOn
, break
, span
, elem
, all
, any
, find
, filter
, reverse
, sortBy
, intersperse
, unsafeCopyToPtr
) where
import GHC.Prim
import GHC.Types
import GHC.ST
import qualified Data.List
import Foundation.Internal.Base
import Foundation.Internal.Proxy
import Foundation.Internal.Primitive
import Foundation.Primitive.Types.OffsetSize
import Foundation.Primitive.Monad
import Foundation.Primitive.Exception
import Foundation.Primitive.Types
import qualified Foundation.Primitive.Block.Mutable as M
import Foundation.Primitive.Block.Mutable (Block(..), MutableBlock(..), new, unsafeThaw, unsafeFreeze)
import Foundation.Primitive.Block.Base
import Foundation.Numerical
-- | Copy the raw contents of a 'Block' to the memory at the given 'Ptr'.
--
-- The number of bytes copied is the full size of the underlying byte array
-- (as reported by 'sizeofByteArray#'). Nothing is checked: the caller must
-- guarantee that the destination is valid and large enough.
unsafeCopyToPtr :: forall ty prim . PrimMonad prim
                => Block ty
                -> Ptr ty
                -> prim ()
unsafeCopyToPtr (Block ba) (Ptr addr) = primitive doCopy
  where
    doCopy st =
        (# compatCopyByteArrayToAddr# ba 0# addr (sizeofByteArray# ba) st, () #)
-- | Build a block of @n@ elements, where the element at each offset is the
-- result of applying @initializer@ to that offset.
--
-- A count of 0 yields 'mempty' without allocating.
create :: forall ty . PrimType ty
       => CountOf ty          -- ^ number of elements
       -> (Offset ty -> ty)   -- ^ element generator, given each offset
       -> Block ty
create n initializer
    | n == 0    = mempty
    | otherwise =
        runST $ new n >>= \mb -> M.iterSet initializer mb >> unsafeFreeze mb
-- | A block containing exactly one element.
singleton :: PrimType ty => ty -> Block ty
singleton x = create 1 (\_ -> x)
-- | A block of the given size in which every element is the same value.
replicate :: PrimType ty => CountOf ty -> ty -> Block ty
replicate n x = create n (\_ -> x)
-- | Copy an immutable block into a freshly allocated, unpinned mutable block.
--
-- The result does not share memory with the input, so writing to it does not
-- affect the original block.
thaw :: (PrimMonad prim, PrimType ty) => Block ty -> prim (MutableBlock ty (PrimState prim))
thaw blk = do
    mb <- M.unsafeNew unpinned nbBytes
    M.unsafeCopyBytesRO mb 0 blk 0 nbBytes
    pure mb
  where
    nbBytes = lengthBytes blk
-- | Copy a mutable block into a new immutable block.
--
-- The bytes are copied into a fresh unpinned buffer. Later writes to the
-- source mutable block therefore do not affect the returned 'Block'.
freeze :: (PrimType ty, PrimMonad prim) => MutableBlock ty (PrimState prim) -> prim (Block ty)
freeze mb = do
    -- Qualified like in 'thaw', so both directions allocate through the
    -- same M.unsafeNew.
    copied <- M.unsafeNew unpinned nbBytes
    M.unsafeCopyBytes copied 0 mb 0 nbBytes
    unsafeFreeze copied
  where
    nbBytes = M.mutableLengthBytes mb
-- | Make an independent copy of a block (thaw into fresh memory, then freeze).
copy :: PrimType ty => Block ty -> Block ty
copy blk = runST $ do
    mb <- thaw blk
    unsafeFreeze mb
-- | Bounds-checked element access.
--
-- An offset outside the block raises an out-of-bound error ('OOB_Index').
index :: PrimType ty => Block ty -> Offset ty -> ty
index blk ofs
    | not (isOutOfBound ofs n) = unsafeIndex blk ofs
    | otherwise                = outOfBound OOB_Index ofs n
  where
    !n = length blk
-- | Apply a function to every element, producing a block of the new
-- element type.
map :: (PrimType a, PrimType b) => (a -> b) -> Block a -> Block b
map f src = create dstLen (f . unsafeIndex src . offsetCast Proxy)
  where
    !dstLen = sizeCast (Proxy :: Proxy (a -> b)) (length src)
-- | Left fold over the elements, from the first offset to the last.
--
-- The accumulator is not forced between steps; use 'foldl'' for a fold that
-- evaluates it eagerly.
foldl :: PrimType ty => (a -> ty -> a) -> a -> Block ty -> a
foldl f z blk = go z 0
  where
    !n = length blk
    go acc ofs
        | ofs .==# n = acc
        | otherwise  = go (f acc (unsafeIndex blk ofs)) (ofs + 1)
-- | Right fold over the elements.
--
-- Each step passes the rest of the fold as an unevaluated argument, so a
-- function that ignores its second argument stops the traversal early.
foldr :: PrimType ty => (ty -> a -> a) -> a -> Block ty -> a
foldr f z blk = go 0
  where
    !n = length blk
    go ofs
        | ofs .==# n = z
        | otherwise  = f (unsafeIndex blk ofs) (go (ofs + 1))
-- | Strict left fold: the accumulator is forced to WHNF at every step.
foldl' :: PrimType ty => (a -> ty -> a) -> a -> Block ty -> a
foldl' f z blk = go 0 z
  where
    !n = length blk
    go ofs !acc
        | ofs .==# n = acc
        | otherwise  = go (ofs + 1) (f acc (unsafeIndex blk ofs))
-- | Prepend an element, producing a new block one element longer.
cons :: PrimType ty => ty -> Block ty -> Block ty
cons x blk
    | n == 0    = singleton x
    | otherwise = runST $ do
        dst <- new (n + 1)
        M.unsafeWrite dst 0 x
        -- The original elements are shifted by one slot.
        M.unsafeCopyElementsRO dst 1 blk 0 n
        unsafeFreeze dst
  where
    !n = length blk
-- | Append an element, producing a new block one element longer.
snoc :: PrimType ty => Block ty -> ty -> Block ty
snoc blk x
    | n == 0    = singleton x
    | otherwise = runST $ do
        dst <- new (n + 1)
        -- The new element goes in the slot just past the copied elements.
        M.unsafeWrite dst (0 `offsetPlusE` n) x
        M.unsafeCopyElementsRO dst 0 blk 0 n
        unsafeFreeze dst
  where
    !n = length blk
-- | Copy out the elements in the half-open offset range [@start@, @end@).
--
-- Bounds are clamped to the block:
--
-- * @end@ is capped at the block length.
-- * A negative @start@ is raised to 0.
--
-- The unchecked element copy therefore never reads outside the block. An
-- empty or inverted range yields 'mempty'.
sub :: PrimType ty => Block ty -> Offset ty -> Offset ty -> Block ty
sub blk start end
    | start' >= end' = mempty
    | otherwise      = runST $ do
        dst <- new newLen
        M.unsafeCopyElementsRO dst 0 blk start' newLen
        unsafeFreeze dst
  where
    !len   = length blk
    start' = max 0 start
    end'   = min (sizeAsOffset len) end
    -- Only evaluated when start' < end', so it is always positive.
    newLen = end' - start'
-- | Split off the first element.
--
-- Returns 'Nothing' for an empty block, otherwise the first element and a
-- copy of the remaining elements.
uncons :: PrimType ty => Block ty -> Maybe (ty, Block ty)
uncons blk
    | n == 0    = Nothing
    | otherwise = Just (firstElem, rest)
  where
    !n        = length blk
    firstElem = unsafeIndex blk 0
    rest      = sub blk 1 (0 `offsetPlusE` n)
-- | Split off the last element.
--
-- Returns 'Nothing' for an empty block, otherwise a copy of all but the last
-- element together with the last element.
unsnoc :: PrimType ty => Block ty -> Maybe (Block ty, ty)
unsnoc blk
    | n == 0    = Nothing
    | otherwise = Just (sub blk 0 lastOfs, unsafeIndex blk lastOfs)
  where
    !n = length blk
    -- Deliberately lazy (no bang): it is only meaningful, and only
    -- evaluated, when the block is non-empty. This avoids computing a
    -- negative count for the empty case.
    lastOfs = 0 `offsetPlusE` (n - 1)
-- | Split a block after the first @nbElems@ elements.
--
-- * A non-positive count gives @(mempty, blk)@.
-- * A count of at least the block length gives @(blk, mempty)@.
-- * Otherwise both halves are fresh copies.
splitAt :: PrimType ty => CountOf ty -> Block ty -> (Block ty, Block ty)
splitAt nbElems blk
    | nbElems <= 0    = (mempty, blk)
    | nbElems >= vlen = (blk, mempty)
    | otherwise       = runST $ do
        left  <- new nbElems
        right <- new rightLen
        M.unsafeCopyElementsRO left 0 blk 0 nbElems
        M.unsafeCopyElementsRO right 0 blk (sizeAsOffset nbElems) rightLen
        (,) <$> unsafeFreeze left <*> unsafeFreeze right
  where
    !vlen    = length blk
    -- Only used when 0 < nbElems < vlen, so it is strictly positive.
    rightLen = vlen - nbElems
-- | Split off the last @n@ elements.
--
-- The result is @(lastN, rest)@: the suffix of length @n@ first, then the
-- remaining prefix.
--
-- * A non-positive @n@ gives @(mempty, blk)@.
-- * An @n@ of at least the block length gives @(blk, mempty)@.
revSplitAt :: PrimType ty => CountOf ty -> Block ty -> (Block ty, Block ty)
revSplitAt n blk
    | n <= 0    = (mempty, blk)
    -- Handled explicitly so `len - n` below is never zero or negative.
    | n >= len  = (blk, mempty)
    | otherwise = let (x, y) = splitAt (len - n) blk in (y, x)
  where
    !len = length blk
-- | Split at the first element that satisfies the predicate.
--
-- The matching element starts the second half. If nothing matches, the
-- result is @(blk, mempty)@.
break :: PrimType ty => (ty -> Bool) -> Block ty -> (Block ty, Block ty)
break predicate blk =
    case firstMatch 0 of
        Nothing  -> (blk, mempty)
        Just ofs -> splitAt (offsetAsSize ofs) blk
  where
    !n = length blk
    firstMatch !ofs
        | ofs .==# n                      = Nothing
        | predicate (unsafeIndex blk ofs) = Just ofs
        | otherwise                       = firstMatch (ofs + 1)
-- | Split at the first element that does NOT satisfy the predicate.
--
-- The first half is the longest prefix whose elements all satisfy it.
span :: PrimType ty => (ty -> Bool) -> Block ty -> (Block ty, Block ty)
span p blk = break (\x -> not (p x)) blk
-- | Whether the block contains an element equal to the given value.
--
-- Scans from the start and stops at the first match.
elem :: PrimType ty => ty -> Block ty -> Bool
elem v blk = go 0
  where
    !n = length blk
    go ofs = not (ofs .==# n) && (unsafeIndex blk ofs == v || go (ofs + 1))
-- | Whether every element satisfies the predicate.
--
-- Returns 'True' for an empty block and stops at the first failure.
all :: PrimType ty => (ty -> Bool) -> Block ty -> Bool
all p blk = go 0
  where
    !n = length blk
    go ofs = ofs .==# n || (p (unsafeIndex blk ofs) && go (ofs + 1))
-- | Whether at least one element satisfies the predicate.
--
-- Returns 'False' for an empty block and stops at the first success.
any :: PrimType ty => (ty -> Bool) -> Block ty -> Bool
any p blk = go 0
  where
    !n = length blk
    go ofs = not (ofs .==# n) && (p (unsafeIndex blk ofs) || go (ofs + 1))
-- | Split the block on every element that satisfies the predicate.
--
-- * Separator elements are dropped.
-- * Adjacent separators, or a separator at either end, produce empty
--   segments.
-- * An empty block yields @[mempty]@.
splitOn :: PrimType ty => (ty -> Bool) -> Block ty -> [Block ty]
splitOn predicate blk
    | len == 0  = [mempty]
    | otherwise = go 0 0
  where
    !len = length blk
    -- segStart: offset where the current segment began; cur: offset being
    -- examined.
    go !segStart !cur
        | cur .==# len                    = [sub blk segStart cur]
        | predicate (unsafeIndex blk cur) = sub blk segStart cur : go next next
        | otherwise                       = go segStart next
      where
        next = cur + 1
-- | The first element satisfying the predicate, if any.
find :: PrimType ty => (ty -> Bool) -> Block ty -> Maybe ty
find predicate blk = go 0
  where
    !n = length blk
    go ofs
        | ofs .==# n  = Nothing
        | predicate x = Just x
        | otherwise   = go (ofs + 1)
      where
        -- Lazy: only read once the bounds check above has passed.
        x = unsafeIndex blk ofs
-- | Keep only the elements that satisfy the predicate, preserving order.
--
-- Implemented by going through an intermediate list.
filter :: PrimType ty => (ty -> Bool) -> Block ty -> Block ty
filter predicate = fromList . Data.List.filter predicate . toList
-- | A new block containing the elements in reverse order.
reverse :: forall ty . PrimType ty => Block ty -> Block ty
reverse blk
    | len == 0  = mempty
    | otherwise = runST (new len >>= \dst -> fill dst >> unsafeFreeze dst)
  where
    !len = length blk
    -- Walk the source forwards while writing the destination backwards,
    -- starting just past the destination's last slot. The explicit signature
    -- keeps the helper polymorphic in `s`, as runST requires.
    fill :: MutableBlock ty s -> ST s ()
    fill dst = go 0 (0 `offsetPlusE` len)
      where
        go src dstEnd
            | src .==# len = pure ()
            | otherwise    = do
                let slot = pred dstEnd
                unsafeWrite dst slot (unsafeIndex blk src)
                go (src + 1) slot
-- | Sort the elements of a block using the given comparison function.
--
-- Steps:
--
-- 1. Copy the input with 'thaw'.
-- 2. Sort the copy in place with a recursive quicksort. Partitioning is
--    Lomuto-style: the last element of each range is the pivot.
-- 3. Freeze and return the sorted copy.
--
-- The input block itself is not modified.
--
-- NOTE(review): the pivot is always the last element of the range. Already
-- sorted or reverse-sorted input therefore leads to quadratic time and
-- recursion depth proportional to the length.
-- NOTE(review): the swap-based partition is not expected to be stable for
-- elements that compare EQ. Confirm before relying on their relative order.
sortBy :: forall ty . PrimType ty => (ty -> ty -> Ordering) -> Block ty -> Block ty
sortBy xford vec
    | len == 0 = mempty
    | otherwise = runST (thaw vec >>= doSort xford)
  where
    len = length vec
    -- Sort the mutable copy in place over offsets 0 .. sizeLastOffset len,
    -- then freeze it.
    doSort :: (PrimType ty, PrimMonad prim) => (ty -> ty -> Ordering) -> MutableBlock ty (PrimState prim) -> prim (Block ty)
    doSort ford ma = qsort 0 (sizeLastOffset len) >> unsafeFreeze ma
      where
        -- Sort the inclusive offset range [lo, hi]. Ranges of 0 or 1
        -- elements are already sorted.
        qsort lo hi
            | lo >= hi = return ()
            | otherwise = do
                p <- partition lo hi
                qsort lo (pred p)
                qsort (p+1) hi
        -- Partition [lo, hi] around the pivot stored at hi.
        -- Elements that do not compare GT to the pivot are swapped to the
        -- front. Returns the offset where the pivot finally lands.
        partition lo hi = do
            pivot <- unsafeRead ma hi
            -- i: next slot for a "not greater than pivot" element.
            -- j: element being examined.
            let loop i j
                    | j == hi = pure i
                    | otherwise = do
                        aj <- unsafeRead ma j
                        i' <- if ford aj pivot == GT
                            then pure i
                            else do
                                ai <- unsafeRead ma i
                                unsafeWrite ma j ai
                                unsafeWrite ma i aj
                                pure $ i + 1
                        loop i' (j+1)
            i <- loop lo lo
            -- Swap the pivot into its final position.
            ai <- unsafeRead ma i
            ahi <- unsafeRead ma hi
            unsafeWrite ma hi ai
            unsafeWrite ma i ahi
            pure i
-- | Insert a separator between every pair of adjacent elements.
--
-- Blocks of 0 or 1 elements are returned unchanged. Otherwise the result
-- has @2 * len - 1@ elements: each original element followed by @sep@,
-- except the last.
intersperse :: forall ty . PrimType ty => ty -> Block ty -> Block ty
intersperse sep blk
    | len <= 1  = blk
    | otherwise = runST $ do
        mb <- new newSize
        go mb
        unsafeFreeze mb
  where
    !len    = length blk
    -- Only used when len >= 2, so neither subtraction goes below 1.
    newSize = len + len - 1
    lastIdx = len - 1
    -- The explicit signature keeps the helper polymorphic in `s` for runST.
    go :: MutableBlock ty s -> ST s ()
    go mb = loop 0 0
      where
        -- o: destination offset; i: source offset.
        loop !o !i
            | i .==# lastIdx = unsafeWrite mb o (unsafeIndex blk i)
            | otherwise      = do
                unsafeWrite mb o (unsafeIndex blk i)
                unsafeWrite mb (o + 1) sep
                loop (o + 2) (i + 1)