{-# LANGUAGE CPP, MultiParamTypeClasses, FlexibleContexts, BangPatterns, TypeFamilies, ScopedTypeVariables #-}
module Data.Vector.Generic.Mutable (
MVector(..),
length, null,
slice, init, tail, take, drop, splitAt,
unsafeSlice, unsafeInit, unsafeTail, unsafeTake, unsafeDrop,
overlaps,
new, unsafeNew, replicate, replicateM, clone,
grow, unsafeGrow,
growFront, unsafeGrowFront,
clear,
read, write, modify, swap, exchange,
unsafeRead, unsafeWrite, unsafeModify, unsafeSwap, unsafeExchange,
set, copy, move, unsafeCopy, unsafeMove,
mstream, mstreamR,
unstream, unstreamR, vunstream,
munstream, munstreamR,
transform, transformR,
fill, fillR,
unsafeAccum, accum, unsafeUpdate, update, reverse,
unstablePartition, unstablePartitionBundle, partitionBundle
) where
import Data.Vector.Generic.Mutable.Base
import qualified Data.Vector.Generic.Base as V
import qualified Data.Vector.Fusion.Bundle as Bundle
import Data.Vector.Fusion.Bundle ( Bundle, MBundle, Chunk(..) )
import qualified Data.Vector.Fusion.Bundle.Monadic as MBundle
import Data.Vector.Fusion.Stream.Monadic ( Stream )
import qualified Data.Vector.Fusion.Stream.Monadic as Stream
import Data.Vector.Fusion.Bundle.Size
import Data.Vector.Fusion.Util ( delay_inline )
import Control.Monad.Primitive ( PrimMonad, PrimState )
import Prelude hiding ( length, null, replicate, reverse, map, read,
take, drop, splitAt, init, tail )
#include "vector.h"
-- | Write @x@ at index @i@ and return the vector that now holds it.
--
-- When @i@ is not below the current length, the vector is first grown with
-- 'enlarge', which adds 'enlarge_delta' (at least one) slots. A single growth
-- step is only guaranteed to be enough when @i@ equals the old length, i.e.
-- when appending.
unsafeAppend1 :: (PrimMonad m, MVector v a)
              => v (PrimState m) a -> Int -> a -> m (v (PrimState m) a)
{-# INLINE_INNER unsafeAppend1 #-}
unsafeAppend1 v i x
  | i >= length v = do
      grown <- enlarge v
      INTERNAL_CHECK(checkIndex) "unsafeAppend1" i (length grown)
        $ unsafeWrite grown i x
      return grown
  | otherwise = do
      unsafeWrite v i x
      return v
-- | Write @x@ into the slot just before index @i@.
--
-- Returns the (possibly reallocated) vector together with the index that was
-- written. When @i@ is 0 there is no free slot at the front, so the vector is
-- first grown at the front via 'enlargeFront'. That call also reports how
-- many slots were added, and the element goes into the last of them.
unsafePrepend1 :: (PrimMonad m, MVector v a)
               => v (PrimState m) a -> Int -> a -> m (v (PrimState m) a, Int)
{-# INLINE_INNER unsafePrepend1 #-}
unsafePrepend1 v i x
  | i == 0 = do
      (grown, added) <- enlargeFront v
      let k = added - 1
      INTERNAL_CHECK(checkIndex) "unsafePrepend1" k (length grown)
        $ unsafeWrite grown k x
      return (grown, k)
  | otherwise = do
      let k = i - 1
      unsafeWrite v k x
      return (v, k)
-- | Stream the elements of a mutable vector from index 0 upwards.
--
-- Each element is read when the stream is stepped. The vector and its length
-- are forced before the stream is built.
mstream :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Stream m a
{-# INLINE mstream #-}
mstream v = v `seq` n `seq` Stream.unfoldrM step 0
  where
    n = length v

    {-# INLINE_INNER step #-}
    step i
      | i >= n = return Nothing
      | otherwise = do
          x <- unsafeRead v i
          return (Just (x, i + 1))
-- | Write the elements of a stream into a mutable vector from index 0
-- upwards, and return the prefix that was written.
--
-- The vector is not grown, so the stream is presumably expected to yield no
-- more elements than the vector holds.
fill :: (PrimMonad m, MVector v a)
     => v (PrimState m) a -> Stream m a -> m (v (PrimState m) a)
{-# INLINE fill #-}
fill v s = v `seq` do
    count <- Stream.foldM store 0 s
    return (unsafeSlice 0 count v)
  where
    {-# INLINE_INNER store #-}
    store i x = do
      INTERNAL_CHECK(checkIndex) "fill" i (length v)
        $ unsafeWrite v i x
      return (i + 1)
-- | Apply a stream transformer to the elements of a mutable vector.
--
-- The results are written back into the same vector from the front, and the
-- written prefix is returned. Since 'fill' writes in place without growing
-- the vector, the transformer presumably must not yield more elements than
-- it receives.
transform
  :: (PrimMonad m, MVector v a)
  => (Stream m a -> Stream m a) -> v (PrimState m) a -> m (v (PrimState m) a)
{-# INLINE_FUSED transform #-}
transform f v = fill v $ f $ mstream v
-- | Stream the elements of a mutable vector from the last index down to 0.
--
-- The vector and its length are forced before the stream is built.
mstreamR :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Stream m a
{-# INLINE mstreamR #-}
mstreamR v = v `seq` n `seq` Stream.unfoldrM step n
  where
    n = length v

    -- The state is one past the next index to read.
    {-# INLINE_INNER step #-}
    step i
      | j < 0 = return Nothing
      | otherwise = do
          x <- unsafeRead v j
          return (Just (x, j))
      where
        j = i - 1
-- | Write the elements of a stream into a mutable vector from the back, and
-- return the suffix that was written.
--
-- Writing starts at the last index and moves towards the front. The vector
-- is not grown, so the stream must not yield more elements than the vector
-- holds. Like 'fill', each write index is validated when internal checks
-- are enabled.
fillR :: (PrimMonad m, MVector v a)
      => v (PrimState m) a -> Stream m a -> m (v (PrimState m) a)
{-# INLINE fillR #-}
fillR v s = v `seq` do
    i <- Stream.foldM put n s
    return $ unsafeSlice i (n-i) v
  where
    n = length v

    -- The accumulator is the lowest index written so far (n initially).
    {-# INLINE_INNER put #-}
    put i x = do
      let j = i-1
      INTERNAL_CHECK(checkIndex) "fillR" j n
        $ unsafeWrite v j x
      return j
-- | Apply a stream transformer to the elements of a mutable vector taken in
-- reverse order.
--
-- The results are written back from the end of the same vector, and the
-- written suffix is returned.
transformR
  :: (PrimMonad m, MVector v a)
  => (Stream m a -> Stream m a) -> v (PrimState m) a -> m (v (PrimState m) a)
{-# INLINE_FUSED transformR #-}
transformR f v = fillR v $ f $ mstreamR v
-- | Create a new mutable vector from a pure 'Bundle'.
--
-- The bundle is lifted into the monad and the work is delegated to
-- 'munstream'.
unstream :: (PrimMonad m, MVector v a)
         => Bundle u a -> m (v (PrimState m) a)
{-# INLINE_FUSED unstream #-}
unstream s = munstream $ Bundle.lift s
-- | Create a new mutable vector from a monadic bundle.
--
-- When the bundle's size has an upper bound, one buffer of that size is
-- preallocated ('munstreamMax'). Otherwise the buffer is grown on demand
-- ('munstreamUnknown').
munstream :: (PrimMonad m, MVector v a)
          => MBundle m u a -> m (v (PrimState m) a)
{-# INLINE_FUSED munstream #-}
munstream s = maybe (munstreamUnknown s) (munstreamMax s)
                    (upperBound (MBundle.size s))
-- | Fill a freshly allocated vector of @n@ slots with the elements of the
-- bundle, front to back, and return the prefix actually written.
--
-- @n@ is assumed to be an upper bound on the bundle's length; writes are
-- only validated against it when internal checks are enabled.
munstreamMax :: (PrimMonad m, MVector v a)
             => MBundle m u a -> Int -> m (v (PrimState m) a)
{-# INLINE munstreamMax #-}
munstreamMax s n = do
  v <- INTERNAL_CHECK(checkLength) "munstreamMax" n
     $ unsafeNew n
  let store i x = do
        INTERNAL_CHECK(checkIndex) "munstreamMax" i n
          $ unsafeWrite v i x
        return (i + 1)
  filled <- MBundle.foldM' store 0 s
  return $ INTERNAL_CHECK(checkSlice) "munstreamMax" 0 filled n
         $ unsafeSlice 0 filled v
-- | Build a vector from a bundle of unknown size.
--
-- Starts from an empty buffer and appends each element with
-- 'unsafeAppend1', which grows the buffer whenever it is full. The prefix
-- holding the elements is returned.
munstreamUnknown :: (PrimMonad m, MVector v a)
                 => MBundle m u a -> m (v (PrimState m) a)
{-# INLINE munstreamUnknown #-}
munstreamUnknown s = do
    v0 <- unsafeNew 0
    (buf, count) <- MBundle.foldM push (v0, 0) s
    return $ INTERNAL_CHECK(checkSlice) "munstreamUnknown" 0 count (length buf)
           $ unsafeSlice 0 count buf
  where
    -- Accumulator: current buffer and number of elements written so far.
    {-# INLINE_INNER push #-}
    push (buf, i) x = do
      buf' <- unsafeAppend1 buf i x
      return (buf', i + 1)
-- | Create a new mutable vector from a pure bundle of immutable-vector
-- chunks.
--
-- The bundle is lifted into the monad and the work is delegated to
-- 'vmunstream'.
vunstream :: (PrimMonad m, V.Vector v a)
          => Bundle v a -> m (V.Mutable v (PrimState m) a)
{-# INLINE_FUSED vunstream #-}
vunstream s = vmunstream $ Bundle.lift s
-- | Create a new mutable vector from a monadic bundle, copying chunk by chunk.
--
-- Preallocates when the bundle's size has an upper bound ('vmunstreamMax');
-- otherwise grows the buffer on demand ('vmunstreamUnknown').
vmunstream :: (PrimMonad m, V.Vector v a)
           => MBundle m v a -> m (V.Mutable v (PrimState m) a)
{-# INLINE_FUSED vmunstream #-}
vmunstream s = maybe (vmunstreamUnknown s) (vmunstreamMax s)
                     (upperBound (MBundle.size s))
-- | Build a mutable vector from a monadic bundle whose size is bounded above
-- by @n@.
--
-- Allocates @n@ slots once. Each chunk's writer action is then run on the
-- slice starting at the current offset, and the prefix actually filled is
-- returned. Internal-check failures are reported under this function's own
-- name.
vmunstreamMax :: (PrimMonad m, V.Vector v a)
              => MBundle m v a -> Int -> m (V.Mutable v (PrimState m) a)
{-# INLINE vmunstreamMax #-}
vmunstreamMax s n = do
  v <- INTERNAL_CHECK(checkLength) "vmunstreamMax" n
     $ unsafeNew n
  -- i is the offset where the next chunk (of length m) starts.
  let {-# INLINE_INNER copyChunk #-}
      copyChunk i (Chunk m f) =
        INTERNAL_CHECK(checkSlice) "vmunstreamMax.copyChunk" i m (length v) $ do
          f (basicUnsafeSlice i m v)
          return (i+m)
  n' <- Stream.foldlM' copyChunk 0 (MBundle.chunks s)
  return $ INTERNAL_CHECK(checkSlice) "vmunstreamMax" 0 n' n
         $ unsafeSlice 0 n' v
-- | Build a mutable vector from a monadic bundle of unknown size, chunk by
-- chunk.
--
-- Starts from an empty buffer. Before each chunk is copied, the buffer is
-- grown (by at least 'enlarge_delta') whenever the chunk would not fit. The
-- prefix that was filled is returned. Internal-check failures are reported
-- under this function's own name.
vmunstreamUnknown :: (PrimMonad m, V.Vector v a)
                 => MBundle m v a -> m (V.Mutable v (PrimState m) a)
{-# INLINE vmunstreamUnknown #-}
vmunstreamUnknown s = do
    v <- unsafeNew 0
    (v', n) <- Stream.foldlM copyChunk (v,0) (MBundle.chunks s)
    return $ INTERNAL_CHECK(checkSlice) "vmunstreamUnknown" 0 n (length v')
           $ unsafeSlice 0 n v'
  where
    -- Accumulator: current buffer and number of elements written so far.
    {-# INLINE_INNER copyChunk #-}
    copyChunk (v,i) (Chunk n f) = do
      let j = i+n
      -- Grow by whichever is larger: the usual growth step or the shortfall.
      -- The new tail is not initialised; only [0, j) is ever exposed.
      v' <- if basicLength v < j
              then unsafeGrow v (delay_inline max (enlarge_delta v) (j - basicLength v))
              else return v
      INTERNAL_CHECK(checkSlice) "vmunstreamUnknown.copyChunk" i n (length v')
        $ f (basicUnsafeSlice i n v')
      return (v',j)
-- | Create a new mutable vector from a pure 'Bundle', filling it from the
-- back.
--
-- The bundle is lifted into the monad and the work is delegated to
-- 'munstreamR'.
unstreamR :: (PrimMonad m, MVector v a)
          => Bundle u a -> m (v (PrimState m) a)
{-# INLINE_FUSED unstreamR #-}
unstreamR s = munstreamR $ Bundle.lift s
-- | Create a new mutable vector from a monadic bundle, writing elements from
-- the back.
--
-- Preallocates when the bundle's size has an upper bound ('munstreamRMax');
-- otherwise grows the buffer at the front on demand ('munstreamRUnknown').
munstreamR :: (PrimMonad m, MVector v a)
           => MBundle m u a -> m (v (PrimState m) a)
{-# INLINE_FUSED munstreamR #-}
munstreamR s = maybe (munstreamRUnknown s) (munstreamRMax s)
                     (upperBound (MBundle.size s))
-- | Fill a freshly allocated vector of @n@ slots from the back with the
-- elements of the bundle, and return the suffix actually written.
--
-- @n@ is assumed to be an upper bound on the bundle's length.
munstreamRMax :: (PrimMonad m, MVector v a)
              => MBundle m u a -> Int -> m (v (PrimState m) a)
{-# INLINE munstreamRMax #-}
munstreamRMax s n = do
  v <- INTERNAL_CHECK(checkLength) "munstreamRMax" n
     $ unsafeNew n
  -- The accumulator is the lowest index written so far (n initially).
  let store i x = do
        let k = i - 1
        INTERNAL_CHECK(checkIndex) "munstreamRMax" k n
          $ unsafeWrite v k x
        return k
  start <- MBundle.foldM' store n s
  return $ INTERNAL_CHECK(checkSlice) "munstreamRMax" start (n - start) n
         $ unsafeSlice start (n - start) v
-- | Build a vector from a bundle of unknown size, writing from the back.
--
-- Starts from an empty buffer and prepends each element with
-- 'unsafePrepend1', which grows the buffer at the front when no free slot
-- remains. Returns the suffix from the last written index to the end.
-- Internal-check failures are reported under this function's own name.
munstreamRUnknown :: (PrimMonad m, MVector v a)
                  => MBundle m u a -> m (v (PrimState m) a)
{-# INLINE munstreamRUnknown #-}
munstreamRUnknown s = do
    v <- unsafeNew 0
    (v', i) <- MBundle.foldM put (v, 0) s
    let n = length v'
    return $ INTERNAL_CHECK(checkSlice) "munstreamRUnknown" i (n-i) n
           $ unsafeSlice i (n-i) v'
  where
    -- Accumulator: current buffer and lowest index written so far.
    {-# INLINE_INNER put #-}
    put (v,i) x = unsafePrepend1 v i x
-- | Length of the mutable vector, as reported by 'basicLength'.
length :: MVector v a => v s a -> Int
{-# INLINE length #-}
length v = basicLength v
-- | 'True' exactly when the vector's 'length' is zero.
null :: MVector v a => v s a -> Bool
{-# INLINE null #-}
null v = case length v of
  0 -> True
  _ -> False
-- | Take @n@ elements starting at index @i@. The slice is bounds-checked
-- against the vector's length before 'unsafeSlice' is used.
slice :: MVector v a => Int -> Int -> v s a -> v s a
{-# INLINE slice #-}
slice i n v = BOUNDS_CHECK(checkSlice) "slice" i n (length v)
            (unsafeSlice i n v)
-- | The first @n@ elements.
--
-- @n@ is clamped to the range from 0 to the vector's length, so this never
-- fails.
take :: MVector v a => Int -> v s a -> v s a
{-# INLINE take #-}
take n v = unsafeSlice 0 count v
  where
    count = min (max n 0) (length v)
-- | All but the first @n@ elements.
--
-- A negative @n@ is treated as 0, and an @n@ beyond the length yields an
-- empty slice at the end.
drop :: MVector v a => Int -> v s a -> v s a
{-# INLINE drop #-}
drop n v = unsafeSlice start len v
  where
    k     = max n 0
    total = length v
    start = min total k
    len   = max 0 (total - k)
-- | Split at index @n@, giving the same pair as @('take' n v, 'drop' n v)@.
--
-- @n@ is clamped in the same way as in those functions.
splitAt :: MVector v a => Int -> v s a -> (v s a, v s a)
{-# INLINE splitAt #-}
splitAt n v = (unsafeSlice 0 cut v, unsafeSlice cut rest v)
  where
    k     = max n 0
    total = length v
    cut   = min k total
    rest  = max 0 (total - k)
-- | All elements but the last.
--
-- This goes through the bounds-checked 'slice', so an empty vector is
-- rejected by that check.
init :: MVector v a => v s a -> v s a
{-# INLINE init #-}
init v = let n = length v in slice 0 (n - 1) v
-- | All elements but the first.
--
-- This goes through the bounds-checked 'slice', so an empty vector is
-- rejected by that check.
tail :: MVector v a => v s a -> v s a
{-# INLINE tail #-}
tail v = let n = length v in slice 1 (n - 1) v
-- | Slice without a bounds check, except when unsafe checks are enabled.
unsafeSlice :: MVector v a => Int  -- ^ starting index
                           -> Int  -- ^ length of the slice
                           -> v s a
                           -> v s a
{-# INLINE unsafeSlice #-}
unsafeSlice i n v = UNSAFE_CHECK(checkSlice) "unsafeSlice" i n (length v)
                  (basicUnsafeSlice i n v)
-- | All elements but the last, via 'unsafeSlice'. The vector must not be
-- empty.
unsafeInit :: MVector v a => v s a -> v s a
{-# INLINE unsafeInit #-}
unsafeInit v = let n = length v in unsafeSlice 0 (n - 1) v
-- | All elements but the first, via 'unsafeSlice'. The vector must not be
-- empty.
unsafeTail :: MVector v a => v s a -> v s a
{-# INLINE unsafeTail #-}
unsafeTail v = let n = length v in unsafeSlice 1 (n - 1) v
-- | The first @n@ elements, without clamping @n@.
unsafeTake :: MVector v a => Int -> v s a -> v s a
{-# INLINE unsafeTake #-}
unsafeTake n = unsafeSlice 0 n
-- | All but the first @n@ elements, without clamping @n@.
unsafeDrop :: MVector v a => Int -> v s a -> v s a
{-# INLINE unsafeDrop #-}
unsafeDrop n v = unsafeSlice n rest v
  where
    rest = length v - n
-- | Whether two vectors overlap, as decided by 'basicOverlaps'.
overlaps :: MVector v a => v s a -> v s a -> Bool
{-# INLINE overlaps #-}
overlaps a b = basicOverlaps a b
-- | Allocate a vector of length @n@ and run 'basicInitialize' on it.
--
-- @n@ is validated with the bounds-checked length check.
new :: (PrimMonad m, MVector v a) => Int -> m (v (PrimState m) a)
{-# INLINE new #-}
new n = BOUNDS_CHECK(checkLength) "new" n $ do
  v <- unsafeNew n
  basicInitialize v
  return v
-- | Allocate a vector of length @n@ with 'basicUnsafeNew'.
--
-- No initialisation is done here, and the length is only checked under
-- unsafe checks.
unsafeNew :: (PrimMonad m, MVector v a) => Int -> m (v (PrimState m) a)
{-# INLINE unsafeNew #-}
unsafeNew n = UNSAFE_CHECK(checkLength) "unsafeNew" n (basicUnsafeNew n)
-- | Delegate to 'basicUnsafeReplicate', clamping a negative length to 0.
replicate :: (PrimMonad m, MVector v a) => Int -> a -> m (v (PrimState m) a)
{-# INLINE replicate #-}
replicate n = basicUnsafeReplicate (delay_inline max 0 n)
-- | Build a vector from the results of running the action @n@ times.
--
-- This goes through 'MBundle.replicateM' and 'munstream'.
replicateM :: (PrimMonad m, MVector v a) => Int -> m a -> m (v (PrimState m) a)
{-# INLINE replicateM #-}
replicateM n act = munstream $ MBundle.replicateM n act
-- | Allocate a new vector of the same length and copy the source into it.
clone :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m (v (PrimState m) a)
{-# INLINE clone #-}
clone src = do
  dst <- unsafeNew (length src)
  unsafeCopy dst src
  return dst
-- | Grow the vector at the end by @by@ slots.
--
-- 'basicInitialize' is run on the newly added region. @by@ is validated
-- with the bounds-checked length check.
grow :: (PrimMonad m, MVector v a)
     => v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE grow #-}
grow v by = BOUNDS_CHECK(checkLength) "grow" by $ do
  vnew <- unsafeGrow v by
  let fresh = basicUnsafeSlice (length v) by vnew
  basicInitialize fresh
  return vnew
-- | Grow the vector at the front by @by@ slots.
--
-- 'basicInitialize' is run on the first @by@ slots of the result. @by@ is
-- validated with the bounds-checked length check.
growFront :: (PrimMonad m, MVector v a)
          => v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE growFront #-}
growFront v by = BOUNDS_CHECK(checkLength) "growFront" by $ do
  vnew <- unsafeGrowFront v by
  let fresh = basicUnsafeSlice 0 by vnew
  basicInitialize fresh
  return vnew
-- | Number of slots that 'enlarge' and 'enlargeFront' add: the current length,
-- or 1 for an empty vector. Repeated growth therefore roughly doubles the
-- size each time.
enlarge_delta :: MVector v a => v s a -> Int
enlarge_delta v = 1 `max` length v
-- | Grow the vector at the end by 'enlarge_delta' slots and initialise them
-- with 'basicInitialize'.
enlarge :: (PrimMonad m, MVector v a)
        => v (PrimState m) a -> m (v (PrimState m) a)
{-# INLINE enlarge #-}
enlarge v = do
    vnew <- unsafeGrow v extra
    basicInitialize (basicUnsafeSlice (length v) extra vnew)
    return vnew
  where
    extra = enlarge_delta v
-- | Grow the vector at the front by 'enlarge_delta' slots and initialise
-- them.
--
-- Returns the new vector together with the number of slots added, which is
-- also the index where the old contents now begin.
enlargeFront :: (PrimMonad m, MVector v a)
             => v (PrimState m) a -> m (v (PrimState m) a, Int)
{-# INLINE enlargeFront #-}
enlargeFront v = do
    grown <- unsafeGrowFront v extra
    basicInitialize (basicUnsafeSlice 0 extra grown)
    return (grown, extra)
  where
    extra = enlarge_delta v
-- | Grow with 'basicUnsafeGrow' without initialising the new slots.
--
-- @n@ is only checked under unsafe checks.
unsafeGrow :: (PrimMonad m, MVector v a)
           => v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE unsafeGrow #-}
unsafeGrow v n = UNSAFE_CHECK(checkLength) "unsafeGrow" n (basicUnsafeGrow v n)
-- | Allocate a vector @by@ slots longer and copy the old contents into its
-- tail, starting at index @by@.
--
-- The leading @by@ slots are left uninitialised.
unsafeGrowFront :: (PrimMonad m, MVector v a)
                => v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE unsafeGrowFront #-}
unsafeGrowFront v by = UNSAFE_CHECK(checkLength) "unsafeGrowFront" by $ do
  let oldLen = length v
  grown <- basicUnsafeNew (oldLen + by)
  basicUnsafeCopy (basicUnsafeSlice by oldLen grown) v
  return grown
-- | Delegate to 'basicClear'.
clear :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m ()
{-# INLINE clear #-}
clear v = basicClear v
-- | Read the element at index @i@, with a bounds check.
read :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> m a
{-# INLINE read #-}
read v i = BOUNDS_CHECK(checkIndex) "read" i (length v)
         (unsafeRead v i)
-- | Write @x@ at index @i@, with a bounds check.
write :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> a -> m ()
{-# INLINE write #-}
write v i x = BOUNDS_CHECK(checkIndex) "write" i (length v)
            (unsafeWrite v i x)
-- | Apply @f@ to the element at index @i@ in place, with a bounds check.
modify :: (PrimMonad m, MVector v a) => v (PrimState m) a -> (a -> a) -> Int -> m ()
{-# INLINE modify #-}
modify v f i = BOUNDS_CHECK(checkIndex) "modify" i (length v)
             (unsafeModify v f i)
-- | Swap the elements at indices @i@ and @j@.
--
-- Both indices are bounds-checked, @i@ first.
swap :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> Int -> m ()
{-# INLINE swap #-}
swap v i j = BOUNDS_CHECK(checkIndex) "swap" i n
           $ BOUNDS_CHECK(checkIndex) "swap" j n
           $ unsafeSwap v i j
  where
    n = length v
-- | Replace the element at index @i@ with @x@ and return the old element,
-- with a bounds check.
exchange :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> a -> m a
{-# INLINE exchange #-}
exchange v i x = BOUNDS_CHECK(checkIndex) "exchange" i (length v)
               (unsafeExchange v i x)
-- | Read via 'basicUnsafeRead'. The index is only checked under unsafe
-- checks.
unsafeRead :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> m a
{-# INLINE unsafeRead #-}
unsafeRead v i = UNSAFE_CHECK(checkIndex) "unsafeRead" i (length v)
               (basicUnsafeRead v i)
-- | Write via 'basicUnsafeWrite'. The index is only checked under unsafe
-- checks.
unsafeWrite :: (PrimMonad m, MVector v a)
            => v (PrimState m) a -> Int -> a -> m ()
{-# INLINE unsafeWrite #-}
unsafeWrite v i x = UNSAFE_CHECK(checkIndex) "unsafeWrite" i (length v)
                  (basicUnsafeWrite v i x)
-- | Read the element at index @i@, apply @f@, and write the result back.
--
-- The index is only checked under unsafe checks.
unsafeModify :: (PrimMonad m, MVector v a) => v (PrimState m) a -> (a -> a) -> Int -> m ()
{-# INLINE unsafeModify #-}
unsafeModify v f i = UNSAFE_CHECK(checkIndex) "unsafeModify" i (length v) $ do
  x <- basicUnsafeRead v i
  basicUnsafeWrite v i (f x)
-- | Swap the elements at indices @i@ and @j@.
--
-- Both elements are read before either is written, so @i == j@ is harmless.
-- The indices are only checked under unsafe checks.
unsafeSwap :: (PrimMonad m, MVector v a)
           => v (PrimState m) a -> Int -> Int -> m ()
{-# INLINE unsafeSwap #-}
unsafeSwap v i j = UNSAFE_CHECK(checkIndex) "unsafeSwap" i n
                 $ UNSAFE_CHECK(checkIndex) "unsafeSwap" j n
                 $ do
                     a <- unsafeRead v i
                     b <- unsafeRead v j
                     unsafeWrite v i b
                     unsafeWrite v j a
  where
    n = length v
-- | Replace the element at index @i@ with @x@ and return the previous
-- element.
--
-- The index is only checked under unsafe checks.
unsafeExchange :: (PrimMonad m, MVector v a)
               => v (PrimState m) a -> Int -> a -> m a
{-# INLINE unsafeExchange #-}
unsafeExchange v i x = UNSAFE_CHECK(checkIndex) "unsafeExchange" i (length v) $ do
  old <- unsafeRead v i
  unsafeWrite v i x
  return old
-- | Delegate to 'basicSet'.
set :: (PrimMonad m, MVector v a) => v (PrimState m) a -> a -> m ()
{-# INLINE set #-}
set v x = basicSet v x
-- | Copy @src@ into @dst@.
--
-- Bounds checks require that the vectors do not overlap (checked first) and
-- that they have the same length.
copy :: (PrimMonad m, MVector v a)
     => v (PrimState m) a -> v (PrimState m) a -> m ()
{-# INLINE copy #-}
copy dst src = BOUNDS_CHECK(check) "copy" "overlapping vectors" disjoint
             $ BOUNDS_CHECK(check) "copy" "length mismatch" sameLength
             $ unsafeCopy dst src
  where
    disjoint   = not (dst `overlaps` src)
    sameLength = length dst == length src
-- | Move @src@ into @dst@ via 'unsafeMove'.
--
-- Only equal lengths are bounds-checked; overlap is not rejected here.
move :: (PrimMonad m, MVector v a)
     => v (PrimState m) a -> v (PrimState m) a -> m ()
{-# INLINE move #-}
move dst src = BOUNDS_CHECK(check) "move" "length mismatch" sameLength
             $ unsafeMove dst src
  where
    sameLength = length dst == length src
-- | Copy @src@ into @dst@ with 'basicUnsafeCopy' after forcing both vectors.
--
-- Equal length and non-overlap are only checked under unsafe checks, length
-- first.
unsafeCopy :: (PrimMonad m, MVector v a) => v (PrimState m) a   -- ^ target
                                         -> v (PrimState m) a   -- ^ source
                                         -> m ()
{-# INLINE unsafeCopy #-}
unsafeCopy dst src = UNSAFE_CHECK(check) "unsafeCopy" "length mismatch" sameLength
                   $ UNSAFE_CHECK(check) "unsafeCopy" "overlapping vectors" disjoint
                   $ (dst `seq` src `seq` basicUnsafeCopy dst src)
  where
    sameLength = length dst == length src
    disjoint   = not (dst `overlaps` src)
-- | Move @src@ into @dst@ with 'basicUnsafeMove' after forcing both vectors.
--
-- Equal length is only checked under unsafe checks.
unsafeMove :: (PrimMonad m, MVector v a) => v (PrimState m) a   -- ^ target
                                         -> v (PrimState m) a   -- ^ source
                                         -> m ()
{-# INLINE unsafeMove #-}
unsafeMove dst src = UNSAFE_CHECK(check) "unsafeMove" "length mismatch" sameLength
                   $ (dst `seq` src `seq` basicUnsafeMove dst src)
  where
    sameLength = length dst == length src
-- | For each @(i, b)@ in the bundle, replace element @i@ with @f old b@.
--
-- Each index is bounds-checked before it is read.
accum :: (PrimMonad m, MVector v a)
      => (a -> b -> a) -> v (PrimState m) a -> Bundle u (Int, b) -> m ()
{-# INLINE accum #-}
accum f !v s = Bundle.mapM_ step s
  where
    !n = length v

    {-# INLINE_INNER step #-}
    step (i, b) = do
      a <- BOUNDS_CHECK(checkIndex) "accum" i n
         $ unsafeRead v i
      unsafeWrite v i (f a b)
-- | For each @(i, b)@ in the bundle, write @b@ at index @i@.
--
-- Each index is bounds-checked.
update :: (PrimMonad m, MVector v a)
       => v (PrimState m) a -> Bundle u (Int, a) -> m ()
{-# INLINE update #-}
update !v s = Bundle.mapM_ step s
  where
    !n = length v

    {-# INLINE_INNER step #-}
    step (i, b) = BOUNDS_CHECK(checkIndex) "update" i n
                (unsafeWrite v i b)
-- | For each @(i, b)@ in the bundle, replace element @i@ with @f old b@.
--
-- Indices are only checked under unsafe checks, and failures are reported
-- as coming from @unsafeAccum@.
unsafeAccum :: (PrimMonad m, MVector v a)
            => (a -> b -> a) -> v (PrimState m) a -> Bundle u (Int, b) -> m ()
{-# INLINE unsafeAccum #-}
unsafeAccum f !v s = Bundle.mapM_ upd s
  where
    {-# INLINE_INNER upd #-}
    upd (i,b) = do
      a <- UNSAFE_CHECK(checkIndex) "unsafeAccum" i n
         $ unsafeRead v i
      unsafeWrite v i (f a b)

    !n = length v
-- | For each @(i, b)@ in the bundle, write @b@ at index @i@.
--
-- Indices are only checked under unsafe checks, and failures are reported
-- as coming from @unsafeUpdate@.
unsafeUpdate :: (PrimMonad m, MVector v a)
             => v (PrimState m) a -> Bundle u (Int, a) -> m ()
{-# INLINE unsafeUpdate #-}
unsafeUpdate !v s = Bundle.mapM_ upd s
  where
    {-# INLINE_INNER upd #-}
    upd (i,b) = UNSAFE_CHECK(checkIndex) "unsafeUpdate" i n
              $ unsafeWrite v i b

    !n = length v
-- | Reverse the vector in place.
--
-- Elements are swapped pairwise from both ends towards the middle.
reverse :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m ()
{-# INLINE reverse #-}
reverse !v = go 0 (length v - 1)
  where
    go i j
      | i >= j    = return ()
      | otherwise = do
          unsafeSwap v i j
          go (i + 1) (j - 1)
-- | Reorder the vector in place so that every element satisfying @f@ comes
-- before every element that does not.
--
-- Returns the number of elements satisfying @f@, i.e. the split index. The
-- relative order within each group is not preserved.
unstablePartition :: forall m v a. (PrimMonad m, MVector v a)
                  => (a -> Bool) -> v (PrimState m) a -> m Int
{-# INLINE unstablePartition #-}
unstablePartition f !v = from_left 0 (length v)
  where
    -- Invariant: indices below i satisfy f; indices at or above j do not.
    -- Scan forward from i until an element failing f is found.
    from_left :: Int -> Int -> m Int
    from_left i j
      | i == j    = return i
      | otherwise = do
          x <- unsafeRead v i
          if f x
            then from_left (i+1) j
            else from_right i (j-1)

    -- Here the element at i is known to fail f, and indices above j fail f.
    -- Scan backward from index j for an element satisfying f. When one is
    -- found, swap it into position i and resume the forward scan.
    from_right :: Int -> Int -> m Int
    from_right i j
      | i == j    = return i
      | otherwise = do
          x <- unsafeRead v j
          if f x
            then do
              y <- unsafeRead v i
              unsafeWrite v i x
              unsafeWrite v j y
              from_left (i+1) j
            else from_right i (j-1)
-- | Split the elements of a bundle into two new vectors: those satisfying
-- the predicate and those that do not.
--
-- With a known upper bound on the size, a single buffer is filled from both
-- ends ('unstablePartitionMax'). Otherwise the stable 'partitionUnknown' is
-- used.
unstablePartitionBundle :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Bundle u a -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE unstablePartitionBundle #-}
unstablePartitionBundle f s =
  maybe (partitionUnknown f s) (unstablePartitionMax f s)
        (upperBound (Bundle.size s))
-- | Partition a bundle of at most @n@ elements into one buffer of @n@ slots.
--
-- Matching elements are written from the front and the others from the
-- back, so the second slice is in reverse arrival order. Under internal
-- checks the two write cursors are verified not to have crossed, matching
-- 'partitionMax'.
unstablePartitionMax :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Bundle u a -> Int
        -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE unstablePartitionMax #-}
unstablePartitionMax f s n = do
  v <- INTERNAL_CHECK(checkLength) "unstablePartitionMax" n
     $ unsafeNew n
  -- i: next free index at the front; j: lowest index used at the back.
  let {-# INLINE_INNER put #-}
      put (i, j) x
        | f x = do
            unsafeWrite v i x
            return (i+1, j)
        | otherwise = do
            unsafeWrite v (j-1) x
            return (i, j-1)
  (i,j) <- Bundle.foldM' put (0, n) s
  INTERNAL_CHECK(check) "unstablePartitionMax" "invalid indices" (i <= j)
    $ return ()
  return (unsafeSlice 0 i v, unsafeSlice j (n-j) v)
-- | Stable partition of a bundle into two new vectors: those satisfying the
-- predicate and those that do not.
--
-- Uses 'partitionMax' when the size has an upper bound, otherwise
-- 'partitionUnknown'.
partitionBundle :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Bundle u a -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE partitionBundle #-}
partitionBundle f s =
  maybe (partitionUnknown f s) (partitionMax f s)
        (upperBound (Bundle.size s))
-- | Stable partition of a bundle of at most @n@ elements using one buffer of
-- @n@ slots.
--
-- Matching elements fill the buffer from the front and the others from the
-- back. The back slice is then reversed in place to restore arrival order.
-- Internal-check failures are reported under this function's own name.
partitionMax :: (PrimMonad m, MVector v a)
  => (a -> Bool) -> Bundle u a -> Int -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE partitionMax #-}
partitionMax f s n = do
  v <- INTERNAL_CHECK(checkLength) "partitionMax" n
     $ unsafeNew n
  -- i: next free index at the front; j: lowest index used at the back.
  let {-# INLINE_INNER put #-}
      put (i,j) x
        | f x = do
            unsafeWrite v i x
            return (i+1,j)
        | otherwise = do
            let j' = j-1
            unsafeWrite v j' x
            return (i,j')
  (i,j) <- Bundle.foldM' put (0,n) s
  INTERNAL_CHECK(check) "partitionMax" "invalid indices" (i <= j)
    $ return ()
  let l = unsafeSlice 0 i v
      r = unsafeSlice j (n-j) v
  reverse r
  return (l,r)
-- | Stable partition of a bundle of unknown size into two growable buffers.
--
-- Each element is appended with 'unsafeAppend1' to the buffer chosen by the
-- predicate, and the filled prefixes of both buffers are returned.
partitionUnknown :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Bundle u a -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE partitionUnknown #-}
partitionUnknown f s = do
    yes0 <- unsafeNew 0
    no0  <- unsafeNew 0
    (yes, ny, no, nn) <- Bundle.foldM' step (yes0, 0, no0, 0) s
    INTERNAL_CHECK(checkSlice) "partitionUnknown" 0 ny (length yes)
      $ INTERNAL_CHECK(checkSlice) "partitionUnknown" 0 nn (length no)
      $ return (unsafeSlice 0 ny yes, unsafeSlice 0 nn no)
  where
    -- Accumulator: each buffer with the count of elements written to it.
    {-# INLINE_INNER step #-}
    step (yes, ny, no, nn) x
      | f x = do
          yes' <- unsafeAppend1 yes ny x
          return (yes', ny + 1, no, nn)
      | otherwise = do
          no' <- unsafeAppend1 no nn x
          return (yes, ny, no', nn + 1)