{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.Massiv.Array.Mutable
(
msize
, read
, readM
, read'
, write
, writeM
, write'
, modify
, modifyM
, modifyM_
, modify'
, swap
, swapM
, swapM_
, swap'
, new
, thaw
, thawS
, freeze
, freezeS
, makeMArray
, makeMArrayLinear
, makeMArrayS
, makeMArrayLinearS
, createArray_
, createArray
, createArrayS_
, createArrayS
, createArrayST_
, createArrayST
, generateArray
, generateArrayLinear
, generateArrayS
, generateArrayLinearS
, generateArrayWS
, generateArrayLinearWS
, unfoldrPrimM_
, iunfoldrPrimM_
, unfoldrPrimM
, iunfoldrPrimM
, unfoldlPrimM_
, iunfoldlPrimM_
, unfoldlPrimM
, iunfoldlPrimM
, forPrimM
, forPrimM_
, iforPrimM
, iforPrimM_
, iforLinearPrimM
, iforLinearPrimM_
, withMArray
, withMArrayS
, withMArrayST
, initialize
, initializeNew
, Mutable
, MArray
, RealWorld
, computeInto
, loadArray
, loadArrayS
) where
import Data.Maybe (fromMaybe)
import Control.Monad (void, when, unless, (>=>))
import Control.Monad.ST
import Control.Scheduler
import Data.Massiv.Core.Common
import Prelude hiding (mapM, read)
-- | Allocate a mutable array of the requested size.
--
-- This is 'initializeNew' with no explicit initial element; what the elements
-- start out as is decided by 'initializeNew' for the chosen representation
-- (not visible here — confirm in the 'Mutable' instance).
new ::
     forall r ix e m. (Mutable r ix e, PrimMonad m)
  => Sz ix -- ^ Size of the array to allocate
  -> m (MArray (PrimState m) r ix e)
new = initializeNew (Nothing :: Maybe e)
{-# INLINE new #-}
-- | Copy an immutable array into a freshly allocated mutable array, in 'IO'.
--
-- The linear range of elements is cut by 'splitLinearly' (driven by the
-- scheduler's worker count) into equal chunks plus a trailing slack chunk.
-- Each chunk is copied as a separately scheduled job. The scheduler uses the
-- source array's own computation strategy ('getComp').
thaw ::
     forall r ix e m. (Mutable r ix e, MonadIO m)
  => Array r ix e -- ^ Source array
  -> m (MArray RealWorld r ix e)
thaw arr = liftIO $ do
  marr <- unsafeNew sz
  withScheduler_ (getComp arr) $ \scheduler ->
    splitLinearly (numWorkers scheduler) n $ \chunk slackStart -> do
      -- Schedule a copy of @k@ elements that start at linear offset @off@.
      let copyFrom off k =
            scheduleWork_ scheduler $ unsafeArrayLinearCopy arr off marr off (SafeSz k)
      loopM_ 0 (< slackStart) (+ chunk) $ \ !off -> copyFrom off chunk
      -- Whatever did not fit into whole chunks is copied as one extra job.
      let slack = n - slackStart
      when (slack > 0) $ copyFrom slackStart slack
  pure marr
  where
    sz = size arr
    n = totalElem sz
{-# INLINE thaw #-}
-- | Sequential counterpart of 'thaw'.
--
-- It allocates a mutable array of the same size as the source and copies
-- every element over with a single linear copy.
thawS ::
     forall r ix e m. (Mutable r ix e, PrimMonad m)
  => Array r ix e -- ^ Source array
  -> m (MArray (PrimState m) r ix e)
thawS arr = do
  marr <- unsafeNew sz
  unsafeArrayLinearCopy arr 0 marr 0 (SafeSz (totalElem sz))
  pure marr
  where
    sz = size arr
{-# INLINE thawS #-}
-- | Copy a mutable array into a newly allocated one and freeze that copy.
--
-- The returned array is backed by the fresh copy, not by @smarr@. Copying is
-- split into linear chunks by 'splitLinearly' (driven by the worker count),
-- plus a trailing slack chunk. Each chunk is scheduled with the given 'Comp',
-- which is also handed to 'unsafeFreeze'.
freeze ::
     forall r ix e m. (Mutable r ix e, MonadIO m)
  => Comp -- ^ Computation strategy for the copy and the result
  -> MArray RealWorld r ix e -- ^ Source mutable array
  -> m (Array r ix e)
freeze comp smarr = liftIO $ do
  tmarr <- unsafeNew sz
  withScheduler_ comp $ \scheduler ->
    splitLinearly (numWorkers scheduler) n $ \chunk slackStart -> do
      -- Schedule a copy of @k@ elements that start at linear offset @off@.
      let copyFrom off k =
            scheduleWork_ scheduler $ unsafeLinearCopy smarr off tmarr off (SafeSz k)
      loopM_ 0 (< slackStart) (+ chunk) $ \ !off -> copyFrom off chunk
      let slack = n - slackStart
      when (slack > 0) $ copyFrom slackStart slack
  unsafeFreeze comp tmarr
  where
    sz = msize smarr
    n = totalElem sz
{-# INLINE freeze #-}
-- | Sequential counterpart of 'freeze'.
--
-- It copies every element into a new mutable array with one linear copy and
-- freezes that copy with the 'Seq' computation strategy.
freezeS ::
     forall r ix e m. (Mutable r ix e, PrimMonad m)
  => MArray (PrimState m) r ix e -- ^ Source mutable array
  -> m (Array r ix e)
freezeS smarr = do
  tmarr <- unsafeNew sz
  unsafeLinearCopy smarr 0 tmarr 0 (SafeSz (totalElem sz))
  unsafeFreeze Seq tmarr
  where
    sz = msize smarr
{-# INLINE freezeS #-}
-- | Allocate a target array for loading @arr@ into.
--
-- The size is the source's 'maxSize', falling back to 'zeroSz' when none is
-- reported. The source's 'defaultElement' is forwarded to 'initializeNew' as
-- the optional initial element.
newMaybeInitialized ::
     (Load r' ix e, Mutable r ix e, PrimMonad m) => Array r' ix e -> m (MArray (PrimState m) r ix e)
newMaybeInitialized !arr = initializeNew initial allocSz
  where
    initial = defaultElement arr
    allocSz = fromMaybe zeroSz (maxSize arr)
{-# INLINE newMaybeInitialized #-}
-- | Sequentially load an array of any 'Load' representation into a new
-- mutable array.
--
-- The target is allocated by 'newMaybeInitialized' and filled with
-- 'unsafeLoadIntoS'; the result is the array that call returns.
loadArrayS ::
     forall r ix e r' m. (Load r' ix e, Mutable r ix e, PrimMonad m)
  => Array r' ix e -- ^ Array to load
  -> m (MArray (PrimState m) r ix e)
loadArrayS arr = newMaybeInitialized arr >>= \marr -> unsafeLoadIntoS marr arr
{-# INLINE loadArrayS #-}
-- | Load an array of any 'Load' representation into a new mutable array, in 'IO'.
--
-- It works like 'loadArrayS' but fills the target with 'unsafeLoadInto'
-- instead of the sequential loader.
loadArray ::
     forall r ix e r' m. (Load r' ix e, Mutable r ix e, MonadIO m)
  => Array r' ix e -- ^ Array to load
  -> m (MArray RealWorld r ix e)
loadArray arr = liftIO (newMaybeInitialized arr >>= (`unsafeLoadInto` arr))
{-# INLINE loadArray #-}
-- | Load @arr@ into an existing mutable array.
--
-- Elements are written by linear index, using a scheduler built from the
-- source's own computation strategy. Only the total element counts must
-- agree; the index types @ix@ and @ix'@ may differ.
--
-- Throws 'SizeElementsMismatchException' (via 'throwM') when the element
-- counts differ.
computeInto ::
     (Load r' ix' e, Mutable r ix e, MonadIO m)
  => MArray RealWorld r ix e -- ^ Target mutable array
  -> Array r' ix' e -- ^ Array to load
  -> m ()
computeInto !mArr !arr = liftIO $ do
  let targetSz = msize mArr
      sourceSz = size arr
  when (totalElem targetSz /= totalElem sourceSz) $
    throwM $ SizeElementsMismatchException targetSz sourceSz
  withScheduler_ (getComp arr) $ \scheduler ->
    loadArrayM scheduler arr (unsafeLinearWrite mArr)
{-# INLINE computeInto #-}
-- | Sequentially create a mutable array by running a monadic action for every
-- multi-dimensional index.
--
-- This is 'makeMArrayLinearS' with each linear index converted by
-- 'fromLinearIndex'.
makeMArrayS ::
     forall r ix e m. (Mutable r ix e, PrimMonad m)
  => Sz ix -- ^ Size of the new array
  -> (ix -> m e) -- ^ Element producer
  -> m (MArray (PrimState m) r ix e)
makeMArrayS sz f = makeMArrayLinearS sz $ \i -> f (fromLinearIndex sz i)
{-# INLINE makeMArrayS #-}
-- | Sequentially create a mutable array by running a monadic action for every
-- linear index.
--
-- Indices run from @0@ upwards to the total element count, and each produced
-- element is written at its index.
makeMArrayLinearS ::
     forall r ix e m. (Mutable r ix e, PrimMonad m)
  => Sz ix -- ^ Size of the new array
  -> (Int -> m e) -- ^ Element producer, given the linear index
  -> m (MArray (PrimState m) r ix e)
makeMArrayLinearS sz f = do
  marr <- unsafeNew sz
  let n = totalElem (msize marr)
  loopM_ 0 (< n) (+ 1) $ \ !i -> do
    e <- f i
    unsafeLinearWrite marr i e
  pure marr
{-# INLINE makeMArrayLinearS #-}
-- | Create a mutable array by running a monadic action for every
-- multi-dimensional index, scheduled according to @comp@.
--
-- This is 'makeMArrayLinear' with each linear index converted by
-- 'fromLinearIndex'.
makeMArray ::
     forall r ix e m. (PrimMonad m, MonadUnliftIO m, Mutable r ix e)
  => Comp -- ^ Computation strategy
  -> Sz ix -- ^ Size of the new array
  -> (ix -> m e) -- ^ Element producer
  -> m (MArray (PrimState m) r ix e)
makeMArray comp sz f = makeMArrayLinear comp sz $ \i -> f (fromLinearIndex sz i)
{-# INLINE makeMArray #-}
-- | Create a mutable array by running a monadic action for every linear index.
--
-- The work is distributed over a scheduler for @comp@ via
-- 'splitLinearlyWithM_', and each element is written at its linear index.
makeMArrayLinear ::
     forall r ix e m. (PrimMonad m, MonadUnliftIO m, Mutable r ix e)
  => Comp -- ^ Computation strategy
  -> Sz ix -- ^ Size of the new array
  -> (Int -> m e) -- ^ Element producer, given the linear index
  -> m (MArray (PrimState m) r ix e)
makeMArrayLinear comp sz f = do
  marr <- unsafeNew sz
  withScheduler_ comp $ \scheduler ->
    splitLinearlyWithM_ scheduler n f (unsafeLinearWrite marr)
  pure marr
  where
    n = totalElem sz
{-# INLINE makeMArrayLinear #-}
-- | Allocate an array with 'new' and hand it, together with a scheduler for
-- @comp@, to @action@.
--
-- Results collected by the scheduler are discarded. Afterwards the array is
-- frozen with 'unsafeFreeze'; presumably no copy is made, so the mutable
-- array should not be touched after this returns — confirm.
createArray_ ::
     forall r ix e a m. (Mutable r ix e, PrimMonad m, MonadUnliftIO m)
  => Comp -- ^ Computation strategy
  -> Sz ix -- ^ Size of the new array
  -> (Scheduler m () -> MArray (PrimState m) r ix e -> m a) -- ^ Filling action
  -> m (Array r ix e)
createArray_ comp sz action = do
  marr <- new sz
  withScheduler_ comp $ \scheduler -> action scheduler marr
  unsafeFreeze comp marr
{-# INLINE createArray_ #-}
-- | Like 'createArray_', but also returns the results the scheduler collected
-- ('withScheduler') from the jobs scheduled by @action@.
createArray ::
     forall r ix e a m b. (Mutable r ix e, PrimMonad m, MonadUnliftIO m)
  => Comp -- ^ Computation strategy
  -> Sz ix -- ^ Size of the new array
  -> (Scheduler m a -> MArray (PrimState m) r ix e -> m b) -- ^ Filling action
  -> m ([a], Array r ix e)
createArray comp sz action = do
  marr <- new sz
  results <- withScheduler comp $ \scheduler -> action scheduler marr
  (,) results <$> unsafeFreeze comp marr
{-# INLINE createArray #-}
-- | Same as 'createArrayS', but the result of the filling action is dropped.
createArrayS_ ::
     forall r ix e a m. (Mutable r ix e, PrimMonad m)
  => Comp -- ^ Computation strategy passed on to the frozen array
  -> Sz ix -- ^ Size of the new array
  -> (MArray (PrimState m) r ix e -> m a) -- ^ Filling action
  -> m (Array r ix e)
createArrayS_ comp sz action = fmap snd (createArrayS comp sz action)
{-# INLINE createArrayS_ #-}
-- | Allocate an array with 'new', run @action@ on it directly (no scheduler),
-- then freeze it with 'unsafeFreeze'.
--
-- @comp@ is only passed to 'unsafeFreeze'; the action's result is returned
-- alongside the frozen array.
createArrayS ::
     forall r ix e a m. (Mutable r ix e, PrimMonad m)
  => Comp -- ^ Computation strategy passed to 'unsafeFreeze'
  -> Sz ix -- ^ Size of the new array
  -> (MArray (PrimState m) r ix e -> m a) -- ^ Filling action
  -> m (a, Array r ix e)
createArrayS comp sz action = do
  marr <- new sz
  result <- action marr
  (,) result <$> unsafeFreeze comp marr
{-# INLINE createArrayS #-}
-- | Pure wrapper around 'createArrayS_' that runs the filling action in 'ST'.
createArrayST_ ::
     forall r ix e a. Mutable r ix e
  => Comp -- ^ Computation strategy passed on to the frozen array
  -> Sz ix -- ^ Size of the new array
  -> (forall s. MArray s r ix e -> ST s a) -- ^ Filling action
  -> Array r ix e
createArrayST_ comp sz action = runST (createArrayS_ comp sz action)
{-# INLINE createArrayST_ #-}
-- | Pure wrapper around 'createArrayS' that runs the filling action in 'ST'
-- and also returns that action's result.
createArrayST ::
     forall r ix e a. Mutable r ix e
  => Comp -- ^ Computation strategy passed on to the frozen array
  -> Sz ix -- ^ Size of the new array
  -> (forall s. MArray s r ix e -> ST s a) -- ^ Filling action
  -> (a, Array r ix e)
createArrayST comp sz action = runST (createArrayS comp sz action)
{-# INLINE createArrayST #-}
-- | Sequentially generate an immutable array from a monadic action over
-- multi-dimensional indices.
--
-- This is 'generateArrayLinearS' with each linear index converted by
-- 'fromLinearIndex'.
generateArrayS ::
     forall r ix e m. (Mutable r ix e, PrimMonad m)
  => Sz ix -- ^ Size of the resulting array
  -> (ix -> m e) -- ^ Element generator
  -> m (Array r ix e)
generateArrayS sz gen = generateArrayLinearS sz $ \i -> gen (fromLinearIndex sz i)
{-# INLINE generateArrayS #-}
-- | Sequentially generate an immutable array from a monadic action over
-- linear indices.
--
-- Indices run from @0@ upwards, and the result is frozen with the 'Seq'
-- computation strategy.
generateArrayLinearS ::
     forall r ix e m. (Mutable r ix e, PrimMonad m)
  => Sz ix -- ^ Size of the resulting array
  -> (Int -> m e) -- ^ Element generator, given the linear index
  -> m (Array r ix e)
generateArrayLinearS sz gen = do
  marr <- unsafeNew sz
  let n = totalElem (msize marr)
  loopM_ 0 (< n) (+ 1) $ \i -> do
    e <- gen i
    unsafeLinearWrite marr i e
  unsafeFreeze Seq marr
{-# INLINE generateArrayLinearS #-}
-- | Generate an immutable array from a monadic action over multi-dimensional
-- indices, scheduled according to @comp@.
--
-- This is 'generateArrayLinear' with each linear index converted by
-- 'fromLinearIndex'.
generateArray ::
     forall r ix e m. (MonadUnliftIO m, PrimMonad m, Mutable r ix e)
  => Comp -- ^ Computation strategy
  -> Sz ix -- ^ Size of the resulting array
  -> (ix -> m e) -- ^ Element generator
  -> m (Array r ix e)
generateArray comp sz f = generateArrayLinear comp sz $ \i -> f (fromLinearIndex sz i)
{-# INLINE generateArray #-}
-- | Generate an immutable array from a monadic action over linear indices.
--
-- The array is built by 'makeMArrayLinear' and then frozen with the same
-- computation strategy.
generateArrayLinear ::
     forall r ix e m. (MonadUnliftIO m, PrimMonad m, Mutable r ix e)
  => Comp -- ^ Computation strategy
  -> Sz ix -- ^ Size of the resulting array
  -> (Int -> m e) -- ^ Element generator, given the linear index
  -> m (Array r ix e)
generateArrayLinear comp sz f = do
  marr <- makeMArrayLinear comp sz f
  unsafeFreeze comp marr
{-# INLINE generateArrayLinear #-}
-- | Generate an immutable array over linear indices with a stateful scheduler.
--
-- @make@ receives the linear index and a state value @s@ supplied through
-- 'WorkerStates' (presumably one state per worker — confirm in the scheduler
-- docs). The result is frozen with the computation strategy that
-- 'workerStatesComp' reports for @states@.
generateArrayLinearWS ::
     forall r ix e s m. (Mutable r ix e, MonadUnliftIO m, PrimMonad m)
  => WorkerStates s -- ^ Worker states
  -> Sz ix -- ^ Size of the resulting array
  -> (Int -> s -> m e) -- ^ Element generator
  -> m (Array r ix e)
generateArrayLinearWS states sz make = do
  marr <- unsafeNew sz
  withSchedulerWS_ states $ \schedulerWS ->
    splitLinearlyWithStatefulM_ schedulerWS n make (unsafeLinearWrite marr)
  unsafeFreeze (workerStatesComp states) marr
  where
    n = totalElem sz
{-# INLINE generateArrayLinearWS #-}
-- | Multi-dimensional-index version of 'generateArrayLinearWS'.
--
-- Each linear index is converted with 'fromLinearIndex' before being passed
-- to @make@.
generateArrayWS ::
     forall r ix e s m. (Mutable r ix e, MonadUnliftIO m, PrimMonad m)
  => WorkerStates s -- ^ Worker states
  -> Sz ix -- ^ Size of the resulting array
  -> (ix -> s -> m e) -- ^ Element generator
  -> m (Array r ix e)
generateArrayWS states sz make = generateArrayLinearWS states sz (make . fromLinearIndex sz)
{-# INLINE generateArrayWS #-}
-- | Same as 'unfoldrPrimM', but the final accumulator is discarded.
unfoldrPrimM_ ::
     forall r ix e a m. (Mutable r ix e, PrimMonad m)
  => Comp -- ^ Computation strategy passed on to the frozen array
  -> Sz ix -- ^ Size of the resulting array
  -> (a -> m (e, a)) -- ^ Step: produce an element and the next accumulator
  -> a -- ^ Initial accumulator
  -> m (Array r ix e)
unfoldrPrimM_ comp sz gen acc0 = fmap snd (unfoldrPrimM comp sz gen acc0)
{-# INLINE unfoldrPrimM_ #-}
-- | Same as 'iunfoldrPrimM', but the final accumulator is discarded.
iunfoldrPrimM_ ::
     forall r ix e a m. (Mutable r ix e, PrimMonad m)
  => Comp -- ^ Computation strategy passed on to the frozen array
  -> Sz ix -- ^ Size of the resulting array
  -> (a -> ix -> m (e, a)) -- ^ Step: given accumulator and index
  -> a -- ^ Initial accumulator
  -> m (Array r ix e)
iunfoldrPrimM_ comp sz gen acc0 = fmap snd (iunfoldrPrimM comp sz gen acc0)
{-# INLINE iunfoldrPrimM_ #-}
-- | Fill a new array front to back by threading an accumulator through a
-- monadic step that also sees each element's index.
--
-- Linear indices go from @0@ up to (excluding) the element count, and the
-- accumulator is forced to WHNF after every step. The final accumulator is
-- returned together with the frozen array.
iunfoldrPrimM ::
     forall r ix e a m. (Mutable r ix e, PrimMonad m)
  => Comp -- ^ Computation strategy passed on to the frozen array
  -> Sz ix -- ^ Size of the resulting array
  -> (a -> ix -> m (e, a)) -- ^ Step: given accumulator and index
  -> a -- ^ Initial accumulator
  -> m (a, Array r ix e)
iunfoldrPrimM comp sz gen acc0 =
  createArrayS comp sz $ \marr -> do
    let msz = msize marr
        step i ix acc = do
          (e, acc') <- gen acc ix
          unsafeLinearWrite marr i e
          pure $! acc'
    iterLinearM msz 0 (totalElem msz) 1 (<) acc0 step
{-# INLINE iunfoldrPrimM #-}
-- | Fill a new array front to back by threading an accumulator through a
-- monadic step.
--
-- Linear indices go from @0@ upwards, and the accumulator is forced to WHNF
-- after every step. The final accumulator is returned together with the
-- frozen array.
unfoldrPrimM ::
     forall r ix e a m. (Mutable r ix e, PrimMonad m)
  => Comp -- ^ Computation strategy passed on to the frozen array
  -> Sz ix -- ^ Size of the resulting array
  -> (a -> m (e, a)) -- ^ Step: produce an element and the next accumulator
  -> a -- ^ Initial accumulator
  -> m (a, Array r ix e)
unfoldrPrimM comp sz gen acc0 =
  createArrayS comp sz $ \marr -> do
    let n = totalElem (msize marr)
        step i acc = do
          (e, acc') <- gen acc
          unsafeLinearWrite marr i e
          pure $! acc'
    loopM 0 (< n) (+ 1) acc0 step
{-# INLINE unfoldrPrimM #-}
-- | Same as 'unfoldlPrimM', but the final accumulator is discarded.
unfoldlPrimM_ ::
     forall r ix e a m. (Mutable r ix e, PrimMonad m)
  => Comp -- ^ Computation strategy passed on to the frozen array
  -> Sz ix -- ^ Size of the resulting array
  -> (a -> m (a, e)) -- ^ Step: produce the next accumulator and an element
  -> a -- ^ Initial accumulator
  -> m (Array r ix e)
unfoldlPrimM_ comp sz gen acc0 = fmap snd (unfoldlPrimM comp sz gen acc0)
{-# INLINE unfoldlPrimM_ #-}
-- | Same as 'iunfoldlPrimM', but the final accumulator is discarded.
iunfoldlPrimM_ ::
     forall r ix e a m. (Mutable r ix e, PrimMonad m)
  => Comp -- ^ Computation strategy passed on to the frozen array
  -> Sz ix -- ^ Size of the resulting array
  -> (a -> ix -> m (a, e)) -- ^ Step: given accumulator and index
  -> a -- ^ Initial accumulator
  -> m (Array r ix e)
iunfoldlPrimM_ comp sz gen acc0 = fmap snd (iunfoldlPrimM comp sz gen acc0)
{-# INLINE iunfoldlPrimM_ #-}
-- | Fill a new array back to front by threading an accumulator through a
-- monadic step that also sees each element's index.
--
-- Linear indices go from the last element down to @0@, and the accumulator is
-- forced to WHNF after every step. The final accumulator is returned together
-- with the frozen array.
iunfoldlPrimM ::
     forall r ix e a m. (Mutable r ix e, PrimMonad m)
  => Comp -- ^ Computation strategy passed on to the frozen array
  -> Sz ix -- ^ Size of the resulting array
  -> (a -> ix -> m (a, e)) -- ^ Step: given accumulator and index
  -> a -- ^ Initial accumulator
  -> m (a, Array r ix e)
iunfoldlPrimM comp sz gen acc0 =
  createArrayS comp sz $ \marr -> do
    let msz = msize marr
        lastIx = totalElem msz - 1
        step i ix acc = do
          (acc', e) <- gen acc ix
          unsafeLinearWrite marr i e
          pure $! acc'
    iterLinearM msz lastIx 0 (negate 1) (>=) acc0 step
{-# INLINE iunfoldlPrimM #-}
-- | Fill a new array by threading an accumulator through a monadic step,
-- driven by 'loopDeepM' over linear indices @0@ to @n-1@.
--
-- NOTE(review): presumably 'loopDeepM' applies the steps in reverse, so
-- elements are produced from the last index backwards, mirroring
-- 'iunfoldlPrimM' — confirm against 'loopDeepM'. The accumulator is forced to
-- WHNF after every step and returned with the frozen array.
unfoldlPrimM ::
     forall r ix e a m. (Mutable r ix e, PrimMonad m)
  => Comp -- ^ Computation strategy passed on to the frozen array
  -> Sz ix -- ^ Size of the resulting array
  -> (a -> m (a, e)) -- ^ Step: produce the next accumulator and an element
  -> a -- ^ Initial accumulator
  -> m (a, Array r ix e)
unfoldlPrimM comp sz gen acc0 =
  createArrayS comp sz $ \marr -> do
    let n = totalElem (msize marr)
        step i acc = do
          (acc', e) <- gen acc
          unsafeLinearWrite marr i e
          pure $! acc'
    loopDeepM 0 (< n) (+ 1) acc0 step
{-# INLINE unfoldlPrimM #-}
-- | Run an action on every element of a mutable array, in linear index order
-- starting at @0@.
forPrimM_ :: (Mutable r ix e, PrimMonad m) => MArray (PrimState m) r ix e -> (e -> m ()) -> m ()
forPrimM_ marr f = loopM_ 0 (< n) (+ 1) (unsafeLinearRead marr >=> f)
  where
    n = totalElem (msize marr)
{-# INLINE forPrimM_ #-}
-- | Replace every element of a mutable array with the result of a monadic
-- function, in linear index order starting at @0@.
forPrimM :: (Mutable r ix e, PrimMonad m) => MArray (PrimState m) r ix e -> (e -> m e) -> m ()
forPrimM marr f = loopM_ 0 (< n) (+ 1) (unsafeLinearModify marr f)
  where
    n = totalElem (msize marr)
{-# INLINE forPrimM #-}
-- | Like 'forPrimM_', but the action also receives each element's
-- multi-dimensional index.
iforPrimM_ ::
     (Mutable r ix e, PrimMonad m) => MArray (PrimState m) r ix e -> (ix -> e -> m ()) -> m ()
iforPrimM_ marr f = iforLinearPrimM_ marr $ \i -> f (fromLinearIndex sz i)
  where
    sz = msize marr
{-# INLINE iforPrimM_ #-}
-- | Like 'forPrimM', but the function also receives each element's
-- multi-dimensional index.
iforPrimM ::
     (Mutable r ix e, PrimMonad m) => MArray (PrimState m) r ix e -> (ix -> e -> m e) -> m ()
iforPrimM marr f = iforLinearPrimM marr $ \i -> f (fromLinearIndex sz i)
  where
    sz = msize marr
{-# INLINE iforPrimM #-}
-- | Run an action on every element together with its linear index, visiting
-- indices from @0@ upwards.
iforLinearPrimM_ ::
     (Mutable r ix e, PrimMonad m) => MArray (PrimState m) r ix e -> (Int -> e -> m ()) -> m ()
iforLinearPrimM_ marr f = loopM_ 0 (< n) (+ 1) visit
  where
    n = totalElem (msize marr)
    visit i = unsafeLinearRead marr i >>= f i
{-# INLINE iforLinearPrimM_ #-}
-- | Replace every element with the result of a monadic function that also sees
-- the element's linear index, visiting indices from @0@ upwards.
iforLinearPrimM ::
     (Mutable r ix e, PrimMonad m) => MArray (PrimState m) r ix e -> (Int -> e -> m e) -> m ()
iforLinearPrimM marr f = loopM_ 0 (< n) (+ 1) visit
  where
    n = totalElem (msize marr)
    visit i = unsafeLinearModify marr (f i) i
{-# INLINE iforLinearPrimM #-}
-- | Thaw a copy of @arr@ ('thaw') and run @action@ on it with a scheduler that
-- uses the source's computation strategy.
--
-- The mutated copy is then frozen with 'unsafeFreeze'. Scheduler results are
-- discarded; @arr@ itself is not modified, since 'thaw' copies it.
withMArray ::
     (Mutable r ix e, MonadUnliftIO m)
  => Array r ix e -- ^ Source array
  -> (Scheduler m () -> MArray RealWorld r ix e -> m a) -- ^ Mutating action
  -> m (Array r ix e)
withMArray arr action = do
  marr <- thaw arr
  withScheduler_ comp $ \scheduler -> action scheduler marr
  liftIO (unsafeFreeze comp marr)
  where
    comp = getComp arr
{-# INLINE withMArray #-}
-- | Sequential counterpart of 'withMArray'.
--
-- It thaws a copy with 'thawS', runs @action@ on it directly (its result is
-- ignored), then freezes the copy with the source's computation strategy.
withMArrayS ::
     (Mutable r ix e, PrimMonad m)
  => Array r ix e -- ^ Source array
  -> (MArray (PrimState m) r ix e -> m a) -- ^ Mutating action
  -> m (Array r ix e)
withMArrayS arr action =
  thawS arr >>= \marr -> action marr >> unsafeFreeze (getComp arr) marr
{-# INLINE withMArrayS #-}
-- | Pure wrapper around 'withMArrayS' that runs the mutating action in 'ST'.
withMArrayST ::
     Mutable r ix e
  => Array r ix e -- ^ Source array
  -> (forall s . MArray s r ix e -> ST s a) -- ^ Mutating action
  -> Array r ix e
withMArrayST arr f = runST (withMArrayS arr f)
{-# INLINE withMArrayST #-}
-- | Bounds-checked read.
--
-- Returns 'Nothing' when @ix@ fails 'isSafeIndex' for the array's size.
read :: (Mutable r ix e, PrimMonad m) =>
        MArray (PrimState m) r ix e -> ix -> m (Maybe e)
read marr ix
  | isSafeIndex (msize marr) ix = Just <$> unsafeRead marr ix
  | otherwise = pure Nothing
{-# INLINE read #-}
-- | Bounds-checked read.
--
-- Throws 'IndexOutOfBoundsException' via 'throwM' when @ix@ is out of bounds.
readM :: (Mutable r ix e, PrimMonad m, MonadThrow m) =>
         MArray (PrimState m) r ix e -> ix -> m e
readM marr ix =
  read marr ix >>= maybe (throwM (IndexOutOfBoundsException (msize marr) ix)) pure
{-# INLINE readM #-}
-- | Bounds-checked read that raises 'IndexOutOfBoundsException' with the
-- pure 'throw' when @ix@ is out of bounds.
read' :: (Mutable r ix e, PrimMonad m) => MArray (PrimState m) r ix e -> ix -> m e
read' marr ix =
  read marr ix >>= maybe (throw (IndexOutOfBoundsException (msize marr) ix)) pure
{-# INLINE read' #-}
{-# DEPRECATED read' "In favor of more general `readM`" #-}
-- | Bounds-checked write.
--
-- Returns 'True' if the element was written, and 'False' (without writing)
-- when @ix@ fails 'isSafeIndex'.
write :: (Mutable r ix e, PrimMonad m) => MArray (PrimState m) r ix e -> ix -> e -> m Bool
write marr ix e
  | isSafeIndex (msize marr) ix = True <$ unsafeWrite marr ix e
  | otherwise = pure False
{-# INLINE write #-}
-- | Bounds-checked write.
--
-- Throws 'IndexOutOfBoundsException' via 'throwM' when @ix@ is out of bounds.
writeM ::
     (Mutable r ix e, PrimMonad m, MonadThrow m) => MArray (PrimState m) r ix e -> ix -> e -> m ()
writeM marr ix e = do
  written <- write marr ix e
  unless written $ throwM (IndexOutOfBoundsException (msize marr) ix)
{-# INLINE writeM #-}
-- | Bounds-checked write that raises 'IndexOutOfBoundsException' with the
-- pure 'throw' when @ix@ is out of bounds.
write' ::
     (Mutable r ix e, PrimMonad m) => MArray (PrimState m) r ix e -> ix -> e -> m ()
write' marr ix e = do
  written <- write marr ix e
  unless written $ throw (IndexOutOfBoundsException (msize marr) ix)
{-# INLINE write' #-}
{-# DEPRECATED write' "In favor of more general `writeM`" #-}
-- | Bounds-checked in-place modification with a monadic function.
--
-- Wraps the value returned by 'unsafeModify' in 'Just'. Returns 'Nothing',
-- leaving the array untouched, when @ix@ is out of bounds.
modify ::
     (Mutable r ix e, PrimMonad m)
  => MArray (PrimState m) r ix e -- ^ Array to modify
  -> (e -> m e) -- ^ Modifying action
  -> ix -- ^ Index of the element
  -> m (Maybe e)
modify marr f ix
  | isSafeIndex (msize marr) ix = Just <$> unsafeModify marr f ix
  | otherwise = pure Nothing
{-# INLINE modify #-}
-- | Bounds-checked in-place modification with a monadic function.
--
-- Returns whatever 'unsafeModify' returns. Throws
-- 'IndexOutOfBoundsException' via 'throwM' when @ix@ is out of bounds.
modifyM ::
     (Mutable r ix e, PrimMonad m, MonadThrow m)
  => MArray (PrimState m) r ix e -- ^ Array to modify
  -> (e -> m e) -- ^ Modifying action
  -> ix -- ^ Index of the element
  -> m e
modifyM marr f ix =
  if isSafeIndex sz ix
    then unsafeModify marr f ix
    else throwM (IndexOutOfBoundsException sz ix)
  where
    sz = msize marr
{-# INLINE modifyM #-}
-- | Same as 'modifyM', but the returned element is discarded.
modifyM_ ::
     (Mutable r ix e, PrimMonad m, MonadThrow m)
  => MArray (PrimState m) r ix e -- ^ Array to modify
  -> (e -> m e) -- ^ Modifying action
  -> ix -- ^ Index of the element
  -> m ()
modifyM_ marr f = void . modifyM marr f
{-# INLINE modifyM_ #-}
-- | Bounds-checked in-place modification with a pure function.
--
-- Raises 'IndexOutOfBoundsException' with the pure 'throw' when @ix@ is out
-- of bounds.
modify' :: (Mutable r ix e, PrimMonad m) =>
           MArray (PrimState m) r ix e -> (e -> e) -> ix -> m ()
modify' marr f ix =
  modify marr (pure . f) ix
    >>= maybe (throw (IndexOutOfBoundsException (msize marr) ix)) (const (pure ()))
{-# INLINE modify' #-}
{-# DEPRECATED modify' "In favor of more general `modifyM`" #-}
-- | Bounds-checked swap of two elements.
--
-- Returns what 'unsafeSwap' returns, wrapped in 'Just'. Returns 'Nothing'
-- when either index is out of bounds.
swap :: (Mutable r ix e, PrimMonad m) => MArray (PrimState m) r ix e -> ix -> ix -> m (Maybe (e, e))
swap marr ix1 ix2
  | isSafeIndex sz ix1 && isSafeIndex sz ix2 = Just <$> unsafeSwap marr ix1 ix2
  | otherwise = pure Nothing
  where
    sz = msize marr
{-# INLINE swap #-}
-- | Bounds-checked swap of two elements.
--
-- @ix1@ is checked first. Throws 'IndexOutOfBoundsException' via 'throwM'
-- for the first index found out of bounds.
swapM ::
     (Mutable r ix e, PrimMonad m, MonadThrow m)
  => MArray (PrimState m) r ix e -- ^ Array to modify
  -> ix -- ^ First index
  -> ix -- ^ Second index
  -> m (e, e)
swapM marr ix1 ix2 =
  case (isSafeIndex sz ix1, isSafeIndex sz ix2) of
    (False, _) -> throwM $ IndexOutOfBoundsException sz ix1
    (_, False) -> throwM $ IndexOutOfBoundsException sz ix2
    _ -> unsafeSwap marr ix1 ix2
  where
    !sz = msize marr
{-# INLINE swapM #-}
-- | Same as 'swapM', but the returned pair is discarded.
swapM_ ::
     (Mutable r ix e, PrimMonad m, MonadThrow m) => MArray (PrimState m) r ix e -> ix -> ix -> m ()
swapM_ marr ix1 = void . swapM marr ix1
{-# INLINE swapM_ #-}
-- | Bounds-checked swap that raises 'IndexOutOfBoundsException' with the pure
-- 'throw'.
--
-- The exception reports @ix1@ if it is out of bounds, otherwise @ix2@.
swap' ::
     (Mutable r ix e, PrimMonad m) => MArray (PrimState m) r ix e -> ix -> ix -> m ()
swap' marr ix1 ix2 = do
  swapped <- swap marr ix1 ix2
  case swapped of
    Just _ -> pure ()
    Nothing
      | isSafeIndex sz ix1 -> throw $ IndexOutOfBoundsException sz ix2
      | otherwise -> throw $ IndexOutOfBoundsException sz ix1
  where
    sz = msize marr
{-# INLINE swap' #-}
{-# DEPRECATED swap' "In favor of more general `swapM`" #-}