{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.Massiv.Array.Ops.Map
( map
, imap
, traverseA
, traverseA_
, itraverseA
, itraverseA_
, traverseAR
, itraverseAR
, sequenceA
, sequenceA_
, traversePrim
, itraversePrim
, traversePrimR
, itraversePrimR
, mapM
, mapMR
, forM
, forMR
, imapM
, imapMR
, iforM
, iforMR
, mapM_
, forM_
, imapM_
, iforM_
, mapIO
, mapWS
, mapIO_
, imapIO
, imapWS
, imapIO_
, forIO
, forWS
, forIO_
, iforIO
, iforWS
, iforIO_
, imapSchedulerM_
, iforSchedulerM_
, zip
, zip3
, unzip
, unzip3
, zipWith
, zipWith3
, izipWith
, izipWith3
, liftArray2
, zipWithA
, izipWithA
, zipWith3A
, izipWith3A
) where
import Control.Monad (void)
import Control.Monad.Primitive (PrimMonad)
import Control.Scheduler
import Data.Coerce
import Data.Massiv.Array.Delayed.Pull
import Data.Massiv.Array.Mutable
import Data.Massiv.Array.Ops.Construct (makeArrayA)
import Data.Massiv.Core.Common
import Data.Massiv.Core.Index.Internal (Sz(..))
import Prelude hiding (map, mapM, mapM_, sequenceA, traverse, unzip, unzip3,
zip, zip3, zipWith, zipWith3)
-- | Apply a pure function to every element of an array.
--
-- The result is a delayed ('D') array with the same size and computation
-- strategy as the source; elements are only computed when the result is
-- consumed.
map :: Source r ix e' => (e' -> e) -> Array r ix e' -> Array D ix e
map f = imap (\ _ -> f)
{-# INLINE map #-}
-- | Index-aware version of 'map': the function receives the index of each
-- element together with the element itself.
--
-- The resulting delayed array keeps the size and computation strategy of the
-- source. The source array is forced to weak head normal form up front (bang
-- pattern), while individual elements are only read on demand.
imap :: Source r ix e' => (ix -> e' -> e) -> Array r ix e' -> Array D ix e
imap f !arr = DArray comp sz elementAt
  where
    comp = getComp arr
    sz = size arr
    elementAt !ix = f ix (unsafeIndex arr ix)
{-# INLINE imap #-}
-- | Pair up the elements of two arrays found at the same index.
--
-- Like 'zipWith', the result only spans the region both arrays have in
-- common: each dimension is the smaller of the two input sizes.
zip
  :: (Source r1 ix e1, Source r2 ix e2)
  => Array r1 ix e1 -> Array r2 ix e2 -> Array D ix (e1, e2)
zip = zipWith (\ e1 e2 -> (e1, e2))
{-# INLINE zip #-}
-- | Group the elements of three arrays found at the same index into triples.
--
-- Like 'zipWith3', the result only spans the region all three arrays have in
-- common: each dimension is the smallest of the three input sizes.
zip3
  :: (Source r1 ix e1, Source r2 ix e2, Source r3 ix e3)
  => Array r1 ix e1 -> Array r2 ix e2 -> Array r3 ix e3 -> Array D ix (e1, e2, e3)
zip3 = zipWith3 (\ e1 e2 e3 -> (e1, e2, e3))
{-# INLINE zip3 #-}
-- | Split an array of pairs into two delayed arrays, one per component.
--
-- Both results are built with 'map', so they share the size and computation
-- strategy of the input array.
unzip :: Source r ix (e1, e2) => Array r ix (e1, e2) -> (Array D ix e1, Array D ix e2)
unzip arr = (firsts, seconds)
  where
    firsts = map fst arr
    seconds = map snd arr
{-# INLINE unzip #-}
-- | Split an array of triples into three delayed arrays, one per component.
--
-- All three results are built with 'map', so they share the size and
-- computation strategy of the input array.
unzip3
  :: Source r ix (e1, e2, e3)
  => Array r ix (e1, e2, e3) -> (Array D ix e1, Array D ix e2, Array D ix e3)
unzip3 arr = (map pick1 arr, map pick2 arr, map pick3 arr)
  where
    pick1 (e, _, _) = e
    pick2 (_, e, _) = e
    pick3 (_, _, e) = e
{-# INLINE unzip3 #-}
-- | Combine two arrays element-wise with a pure function.
--
-- Sizes need not match: the result covers only the overlapping region (the
-- per-dimension minimum of both sizes), and its computation strategy is the
-- combination of both inputs. See 'liftArray2' for a variant that requires
-- matching sizes.
zipWith
  :: (Source r1 ix e1, Source r2 ix e2)
  => (e1 -> e2 -> e) -> Array r1 ix e1 -> Array r2 ix e2 -> Array D ix e
zipWith f = izipWith (const f)
{-# INLINE zipWith #-}
-- | Index-aware 'zipWith': the combining function also receives the index.
--
-- The result size is the per-dimension minimum of both input sizes, so only
-- indices valid in both arrays are ever read. The computation strategy is the
-- combination of the strategies of both inputs.
izipWith
  :: (Source r1 ix e1, Source r2 ix e2)
  => (ix -> e1 -> e2 -> e) -> Array r1 ix e1 -> Array r2 ix e2 -> Array D ix e
izipWith f arr1 arr2 = DArray comp sz elementAt
  where
    comp = getComp arr1 <> getComp arr2
    sz = SafeSz (liftIndex2 min (unSz (size arr1)) (unSz (size arr2)))
    elementAt !ix = f ix (unsafeIndex arr1 ix) (unsafeIndex arr2 ix)
{-# INLINE izipWith #-}
-- | Combine three arrays element-wise with a pure function.
--
-- The result covers only the region shared by all three arrays (the
-- per-dimension minimum of their sizes).
zipWith3
  :: (Source r1 ix e1, Source r2 ix e2, Source r3 ix e3)
  => (e1 -> e2 -> e3 -> e) -> Array r1 ix e1 -> Array r2 ix e2 -> Array r3 ix e3 -> Array D ix e
zipWith3 f = izipWith3 (const f)
{-# INLINE zipWith3 #-}
-- | Index-aware 'zipWith3': the combining function also receives the index.
--
-- The result size is the per-dimension minimum of all three input sizes, and
-- the computation strategy is the combination of all three strategies.
izipWith3
  :: (Source r1 ix e1, Source r2 ix e2, Source r3 ix e3)
  => (ix -> e1 -> e2 -> e3 -> e)
  -> Array r1 ix e1
  -> Array r2 ix e2
  -> Array r3 ix e3
  -> Array D ix e
izipWith3 f arr1 arr2 arr3 = DArray comp sz elementAt
  where
    comp = getComp arr1 <> getComp arr2 <> getComp arr3
    sz1 = unSz (size arr1)
    sz2 = unSz (size arr2)
    sz3 = unSz (size arr3)
    sz = SafeSz (liftIndex2 min (liftIndex2 min sz1 sz2) sz3)
    elementAt !ix = f ix (unsafeIndex arr1 ix) (unsafeIndex arr2 ix) (unsafeIndex arr3 ix)
{-# INLINE izipWith3 #-}
-- | Applicative version of 'zipWith' that builds a new array of the
-- caller-chosen mutable representation @r@.
--
-- The result covers only the overlapping region of the two inputs.
zipWithA
  :: (Source r1 ix e1, Source r2 ix e2, Applicative f, Mutable r ix e)
  => (e1 -> e2 -> f e)
  -> Array r1 ix e1
  -> Array r2 ix e2
  -> f (Array r ix e)
zipWithA f = izipWithA (\ _ -> f)
{-# INLINE zipWithA #-}
-- | Index-aware, applicative version of 'zipWith'.
--
-- The array is built with 'makeArrayA' over the per-dimension minimum of both
-- input sizes. Its computation strategy is then set to the combination of the
-- strategies of both inputs.
izipWithA
  :: (Source r1 ix e1, Source r2 ix e2, Applicative f, Mutable r ix e)
  => (ix -> e1 -> e2 -> f e)
  -> Array r1 ix e1
  -> Array r2 ix e2
  -> f (Array r ix e)
izipWithA f arr1 arr2 = fmap (setComp comp) (makeArrayA sz elementAction)
  where
    comp = getComp arr1 <> getComp arr2
    sz = SafeSz (liftIndex2 min (coerce (size arr1)) (coerce (size arr2)))
    elementAction !ix = f ix (unsafeIndex arr1 ix) (unsafeIndex arr2 ix)
{-# INLINE izipWithA #-}
-- | Applicative version of 'zipWith3' that builds a new array of the
-- caller-chosen mutable representation @r@.
--
-- The result covers only the region shared by all three inputs.
zipWith3A
  :: (Source r1 ix e1, Source r2 ix e2, Source r3 ix e3, Applicative f, Mutable r ix e)
  => (e1 -> e2 -> e3 -> f e)
  -> Array r1 ix e1
  -> Array r2 ix e2
  -> Array r3 ix e3
  -> f (Array r ix e)
zipWith3A f = izipWith3A (\ _ -> f)
{-# INLINE zipWith3A #-}
-- | Index-aware, applicative version of 'zipWith3'.
--
-- The array is built with 'makeArrayA' over the per-dimension minimum of the
-- three input sizes. Its computation strategy is then set to the combination
-- of all three input strategies.
izipWith3A
  :: (Source r1 ix e1, Source r2 ix e2, Source r3 ix e3, Applicative f, Mutable r ix e)
  => (ix -> e1 -> e2 -> e3 -> f e)
  -> Array r1 ix e1
  -> Array r2 ix e2
  -> Array r3 ix e3
  -> f (Array r ix e)
izipWith3A f arr1 arr2 arr3 = fmap (setComp comp) (makeArrayA sz elementAction)
  where
    comp = getComp arr1 <> getComp arr2 <> getComp arr3
    sz =
      SafeSz
        (liftIndex2
           min
           (liftIndex2 min (coerce (size arr1)) (coerce (size arr2)))
           (coerce (size arr3)))
    elementAction !ix =
      f ix (unsafeIndex arr1 ix) (unsafeIndex arr2 ix) (unsafeIndex arr3 ix)
{-# INLINE izipWith3A #-}
-- | Combine two arrays element-wise, requiring compatible sizes.
--
-- * If either array has exactly one element ('oneSz'), that element is
--   broadcast against every element of the other array. The result then has
--   the size and computation strategy of the other array.
--
-- * Otherwise both sizes must be equal. The result has that size and the
--   combined computation strategy of both inputs.
--
-- Throws 'SizeMismatchException' when neither condition holds.
liftArray2
  :: (Source r1 ix a, Source r2 ix b)
  => (a -> b -> e) -> Array r1 ix a -> Array r2 ix b -> Array D ix e
liftArray2 f !arr1 !arr2
  | sz1 == oneSz = map (f scalar1) arr2
  | sz2 == oneSz = map (flip f scalar2) arr1
  | sz1 == sz2 = DArray (getComp arr1 <> getComp arr2) sz1 elementAt
  | otherwise = throw (SizeMismatchException sz1 sz2)
  where
    sz1 = size arr1
    sz2 = size arr2
    -- Only demanded in the broadcasting branches, where the array has a single
    -- element, so zeroIndex is guaranteed to be in bounds.
    scalar1 = unsafeIndex arr1 zeroIndex
    scalar2 = unsafeIndex arr2 zeroIndex
    elementAt !ix = f (unsafeIndex arr1 ix) (unsafeIndex arr2 ix)
{-# INLINE liftArray2 #-}
-- | Traverse an array with an 'Applicative' action and collect the results
-- into a new array of the caller-chosen mutable representation @r@.
--
-- The resulting array inherits the computation strategy of the source array.
traverseA
  :: forall r ix e r' a f. (Source r' ix a, Mutable r ix e, Applicative f)
  => (a -> f e)
  -> Array r' ix a
  -> f (Array r ix e)
traverseA f arr = fmap (setComp (getComp arr)) (makeArrayA (size arr) elementAction)
  where
    elementAction ix = f (unsafeIndex arr ix)
{-# INLINE traverseA #-}
-- | Traverse an array with an 'Applicative' action for its effects only.
--
-- Elements are visited by linear index, from 0 up to (but excluding) the total
-- number of elements. The results of the action are discarded.
traverseA_
  :: forall r ix e a f. (Source r ix e, Applicative f)
  => (e -> f a)
  -> Array r ix e
  -> f ()
traverseA_ f arr = loopA_ 0 (< count) (+ 1) (\ i -> f (unsafeLinearIndex arr i))
  where
    count = totalElem (size arr)
{-# INLINE traverseA_ #-}
-- | Run every action stored in an array and collect the results into a new
-- array of the caller-chosen mutable representation @r@.
--
-- This is 'traverseA' with the identity function.
sequenceA
  :: forall r ix e r' f. (Source r' ix (f e), Mutable r ix e, Applicative f)
  => Array r' ix (f e)
  -> f (Array r ix e)
sequenceA arr = traverseA id arr
{-# INLINE sequenceA #-}
-- | Run every action stored in an array for its effects, discarding results.
--
-- This is 'traverseA_' with the identity function.
sequenceA_
  :: forall r ix e f. (Source r ix (f e), Applicative f)
  => Array r ix (f e)
  -> f ()
sequenceA_ arr = traverseA_ id arr
{-# INLINE sequenceA_ #-}
-- | Index-aware 'traverseA': the action also receives the index of each
-- element.
--
-- The resulting array inherits the computation strategy of the source array.
itraverseA
  :: forall r ix e r' a f. (Source r' ix a, Mutable r ix e, Applicative f)
  => (ix -> a -> f e)
  -> Array r' ix a
  -> f (Array r ix e)
itraverseA f arr = fmap (setComp (getComp arr)) (makeArrayA (size arr) elementAction)
  where
    elementAction !ix = f ix (unsafeIndex arr ix)
{-# INLINE itraverseA #-}
-- | Index-aware 'traverseA_': the action receives each element together with
-- its index, and its results are discarded.
--
-- Elements are visited by linear index; each linear index is converted to a
-- full index with 'fromLinearIndex' before the action is invoked.
itraverseA_
  :: forall r ix e a f. (Source r ix a, Applicative f)
  => (ix -> a -> f e)
  -> Array r ix a
  -> f ()
itraverseA_ f arr = loopA_ 0 (< totalElem sz) (+ 1) visit
  where
    sz = size arr
    visit !i = f (fromLinearIndex sz i) (unsafeLinearIndex arr i)
{-# INLINE itraverseA_ #-}
-- | Same as 'traverseA', but takes a representation argument.
--
-- That argument is ignored at runtime; it only fixes the result representation
-- @r@ at the call site. Deprecated in favour of 'traverseA'.
traverseAR
  :: (Source r' ix a, Mutable r ix b, Applicative f)
  => r
  -> (a -> f b)
  -> Array r' ix a
  -> f (Array r ix b)
traverseAR = const traverseA
{-# INLINE traverseAR #-}
{-# DEPRECATED traverseAR "In favor of `traverseA`" #-}
-- | Same as 'itraverseA', but takes a representation argument.
--
-- That argument is ignored at runtime; it only fixes the result representation
-- @r@ at the call site. Deprecated in favour of 'itraverseA'.
itraverseAR
  :: (Source r' ix a, Mutable r ix b, Applicative f)
  => r
  -> (ix -> a -> f b)
  -> Array r' ix a
  -> f (Array r ix b)
itraverseAR = const itraverseA
{-# INLINE itraverseAR #-}
{-# DEPRECATED itraverseAR "In favor of `itraverseA`" #-}
-- | Traverse an array with an action in a 'PrimMonad' and collect the results
-- into a new array of the caller-chosen mutable representation @r@.
--
-- This is 'itraversePrim' with an action that ignores the index.
traversePrim
  :: forall r ix b r' a m. (Source r' ix a, Mutable r ix b, PrimMonad m)
  => (a -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
traversePrim f = itraversePrim (\ _ -> f)
{-# INLINE traversePrim #-}
-- | Index-aware traversal in a 'PrimMonad'.
--
-- The array is produced by 'generateArrayLinearS', driven by linear indices.
-- Each linear index is converted to a full index with 'fromLinearIndex' before
-- the action is called. The result inherits the computation strategy of the
-- source array.
itraversePrim
  :: forall r ix b r' a m. (Source r' ix a, Mutable r ix b, PrimMonad m)
  => (ix -> a -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
itraversePrim f arr = setComp (getComp arr) <$> generateArrayLinearS sz step
  where
    sz = size arr
    step !i = f (fromLinearIndex sz i) (unsafeLinearIndex arr i)
{-# INLINE itraversePrim #-}
-- | Same as 'traversePrim', but takes a representation argument.
--
-- That argument is ignored at runtime; it only fixes the result representation
-- @r@ at the call site. Deprecated in favour of 'traversePrim'.
traversePrimR
  :: (Source r' ix a, Mutable r ix b, PrimMonad m)
  => r
  -> (a -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
traversePrimR = const traversePrim
{-# INLINE traversePrimR #-}
{-# DEPRECATED traversePrimR "In favor of `traversePrim`" #-}
-- | Same as 'itraversePrim', but takes a representation argument.
--
-- That argument is ignored at runtime; it only fixes the result representation
-- @r@ at the call site. Deprecated in favour of 'itraversePrim'.
itraversePrimR
  :: (Source r' ix a, Mutable r ix b, PrimMonad m)
  => r
  -> (ix -> a -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
itraversePrimR = const itraversePrim
{-# INLINE itraversePrimR #-}
{-# DEPRECATED itraversePrimR "In favor of `itraversePrim`" #-}
-- | Map a monadic action over an array, collecting the results into a new
-- array of the caller-chosen mutable representation @r@.
--
-- This is 'traverseA' restricted to a 'Monad'.
mapM
  :: forall r ix b r' a m. (Source r' ix a, Mutable r ix b, Monad m)
  => (a -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
mapM f arr = traverseA f arr
{-# INLINE mapM #-}
-- | Same as 'mapM', but takes a representation argument.
--
-- That argument is ignored at runtime; it only fixes the result representation
-- @r@ at the call site.
mapMR
  :: forall r ix b r' a m. (Source r' ix a, Mutable r ix b, Monad m)
  => r
  -> (a -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
mapMR = const traverseA
{-# INLINE mapMR #-}
-- | Same as 'mapM', but with the array as the first argument.
forM
  :: forall r ix b r' a m. (Source r' ix a, Mutable r ix b, Monad m)
  => Array r' ix a
  -> (a -> m b)
  -> m (Array r ix b)
forM arr f = traverseA f arr
{-# INLINE forM #-}
-- | Same as 'forM', but takes a representation argument.
--
-- That argument is ignored at runtime; it only fixes the result representation
-- @r@ at the call site.
forMR
  :: forall r ix b r' a m. (Source r' ix a, Mutable r ix b, Monad m)
  => r
  -> Array r' ix a
  -> (a -> m b)
  -> m (Array r ix b)
forMR _ arr f = traverseA f arr
{-# INLINE forMR #-}
-- | Index-aware 'mapM': the action also receives the index of each element.
--
-- This is 'itraverseA' restricted to a 'Monad'.
imapM
  :: forall r ix b r' a m. (Source r' ix a, Mutable r ix b, Monad m)
  => (ix -> a -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
imapM f arr = itraverseA f arr
{-# INLINE imapM #-}
-- | Same as 'imapM', but takes a representation argument.
--
-- That argument is ignored at runtime; it only fixes the result representation
-- @r@ at the call site.
imapMR
  :: forall r ix b r' a m. (Source r' ix a, Mutable r ix b, Monad m)
  => r
  -> (ix -> a -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
imapMR = const itraverseA
{-# INLINE imapMR #-}
-- | Index-aware monadic map producing a new array of representation @r@.
--
-- NOTE(review): despite the @for@ prefix, this takes the action first and the
-- array second, exactly like 'imapM'. This is unlike 'forM', 'iforM_' and
-- 'iforIO', which take the array first. Flipping the arguments would break
-- existing callers, so the order is left unchanged; confirm whether it is
-- intentional.
iforM
  :: forall r ix b r' a m. (Source r' ix a, Mutable r ix b, Monad m)
  => (ix -> a -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
iforM f arr = itraverseA f arr
{-# INLINE iforM #-}
-- | Same as 'iforM', but takes a representation argument that is ignored at
-- runtime and only fixes the result representation @r@.
--
-- NOTE(review): like 'iforM', this takes the action before the array, which
-- is unlike 'forMR'. The order is kept for backward compatibility; confirm
-- whether it is intentional.
iforMR
  :: forall r ix b r' a m. (Source r' ix a, Mutable r ix b, Monad m)
  => r
  -> (ix -> a -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
iforMR _ f arr = itraverseA f arr
{-# INLINE iforMR #-}
-- | Run a monadic action on every element of an array, discarding the
-- results.
--
-- Indices are visited with 'iterM_' from 'zeroIndex' up to, but excluding,
-- the array size, with a step of one in every dimension.
mapM_ :: (Source r ix a, Monad m) => (a -> m b) -> Array r ix a -> m ()
mapM_ f !arr = iterM_ zeroIndex upperBound (pureIndex 1) (<) (\ ix -> f (unsafeIndex arr ix))
  where
    upperBound = unSz (size arr)
{-# INLINE mapM_ #-}
-- | Same as 'mapM_', but with the array as the first argument.
forM_ :: (Source r ix a, Monad m) => Array r ix a -> (a -> m b) -> m ()
forM_ arr f = mapM_ f arr
{-# INLINE forM_ #-}
-- | Same as 'imapM_', but with the array as the first argument.
iforM_ :: (Source r ix a, Monad m) => Array r ix a -> (ix -> a -> m b) -> m ()
iforM_ arr f = imapM_ f arr
{-# INLINE iforM_ #-}
-- | Map an action in a 'MonadUnliftIO' monad over an array and collect the
-- results into a new array of representation @r@.
--
-- This is 'imapIO' with an action that ignores the index. Work is scheduled
-- according to the computation strategy of the source array, so with a
-- parallel strategy the action may run on several workers.
mapIO
  :: forall r ix b r' a m. (Source r' ix a, Mutable r ix b, MonadUnliftIO m, PrimMonad m)
  => (a -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
mapIO action = imapIO (\ _ -> action)
{-# INLINE mapIO #-}
-- | Like 'mapIO', but runs the action only for its effects; results are
-- discarded.
--
-- This is 'imapIO_' with an action that ignores the index, so scheduling
-- follows the computation strategy of the source array.
mapIO_ :: (Source r ix e, MonadUnliftIO m) => (e -> m a) -> Array r ix e -> m ()
mapIO_ action = imapIO_ (\ _ -> action)
{-# INLINE mapIO_ #-}
-- | Index-aware 'mapIO_': run an action on every element and its index, for
-- its effects only.
--
-- A scheduler is created with 'withScheduler_' from the computation strategy
-- of the array, and the work is distributed by 'imapSchedulerM_'.
imapIO_ :: (Source r ix e, MonadUnliftIO m) => (ix -> e -> m a) -> Array r ix e -> m ()
imapIO_ action arr = withScheduler_ comp (\ scheduler -> imapSchedulerM_ scheduler action arr)
  where
    comp = getComp arr
{-# INLINE imapIO_ #-}
-- | Schedule an index-aware action over every element of an array on an
-- existing 'Scheduler', discarding the action's results.
--
-- The linear range of element positions is handed to 'splitLinearlyWith_'.
-- Elements are read with 'unsafeLinearIndex', and each linear position is
-- converted to a full index with 'fromLinearIndex' before the action is
-- invoked.
imapSchedulerM_
  :: (Source r ix e, Monad m) => Scheduler m () -> (ix -> e -> m a) -> Array r ix e -> m ()
imapSchedulerM_ scheduler action arr =
  splitLinearlyWith_ scheduler (totalElem sz) (unsafeLinearIndex arr) runAt
  where
    sz = size arr
    runAt i e = void (action (fromLinearIndex sz i) e)
{-# INLINE imapSchedulerM_ #-}
-- | Same as 'imapSchedulerM_', but takes the array before the action.
iforSchedulerM_
  :: (Source r ix e, Monad m) => Scheduler m () -> Array r ix e -> (ix -> e -> m a) -> m ()
iforSchedulerM_ scheduler = flip (imapSchedulerM_ scheduler)
{-# INLINE iforSchedulerM_ #-}
-- | Index-aware 'mapIO': the action also receives the index of each element.
--
-- The new array is produced by 'generateArray' using the computation strategy
-- and size of the source array.
-- NOTE(review): any ordering guarantees are those of 'generateArray'.
imapIO
  :: forall r ix b r' a m. (Source r' ix a, Mutable r ix b, MonadUnliftIO m, PrimMonad m)
  => (ix -> a -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
imapIO action arr = generateArray comp sz elementAction
  where
    comp = getComp arr
    sz = size arr
    elementAction ix = action ix (unsafeIndex arr ix)
{-# INLINE imapIO #-}
-- | Same as 'mapIO', but with the array as the first argument.
forIO
  :: forall r ix b r' a m. (Source r' ix a, Mutable r ix b, MonadUnliftIO m, PrimMonad m)
  => Array r' ix a
  -> (a -> m b)
  -> m (Array r ix b)
forIO arr action = mapIO action arr
{-# INLINE forIO #-}
-- | Index-aware monadic map whose action also receives a state value @s@.
--
-- The state value is supplied by 'generateArrayWS' from the given
-- 'WorkerStates'. NOTE(review): it is presumably the state of the worker
-- running the action; confirm.
--
-- Unlike 'imapIO', the computation strategy of the source array is not
-- consulted here. Scheduling is left to 'generateArrayWS' and the worker
-- states.
imapWS
  :: forall r ix b r' a s m. (Source r' ix a, Mutable r ix b, MonadUnliftIO m, PrimMonad m)
  => WorkerStates s
  -> (ix -> a -> s -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
imapWS states f arr = generateArrayWS states (size arr) elementAction
  where
    elementAction ix s = f ix (unsafeIndex arr ix) s
{-# INLINE imapWS #-}
-- | Same as 'imapWS', but the action does not receive the element index.
mapWS
  :: forall r ix b r' a s m. (Source r' ix a, Mutable r ix b, MonadUnliftIO m, PrimMonad m)
  => WorkerStates s
  -> (a -> s -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
mapWS states f = imapWS states (const f)
{-# INLINE mapWS #-}
-- | Same as 'imapWS', but takes the array before the action.
iforWS
  :: forall r ix b r' a s m. (Source r' ix a, Mutable r ix b, MonadUnliftIO m, PrimMonad m)
  => WorkerStates s
  -> Array r' ix a
  -> (ix -> a -> s -> m b)
  -> m (Array r ix b)
iforWS states arr f = imapWS states f arr
{-# INLINE iforWS #-}
-- | Same as 'mapWS', but takes the array before the action.
forWS
  :: forall r ix b r' a s m. (Source r' ix a, Mutable r ix b, MonadUnliftIO m, PrimMonad m)
  => WorkerStates s
  -> Array r' ix a
  -> (a -> s -> m b)
  -> m (Array r ix b)
forWS states arr f = mapWS states f arr
{-# INLINE forWS #-}
-- | Same as 'mapIO_', but with the array as the first argument.
forIO_ :: (Source r ix e, MonadUnliftIO m) => Array r ix e -> (e -> m a) -> m ()
forIO_ arr action = mapIO_ action arr
{-# INLINE forIO_ #-}
-- | Same as 'imapIO', but with the array as the first argument.
iforIO
  :: forall r ix b r' a m. (Source r' ix a, Mutable r ix b, MonadUnliftIO m, PrimMonad m)
  => Array r' ix a
  -> (ix -> a -> m b)
  -> m (Array r ix b)
iforIO arr action = imapIO action arr
{-# INLINE iforIO #-}
-- | Same as 'imapIO_', but with the array as the first argument.
iforIO_ :: (Source r ix a, MonadUnliftIO m) => Array r ix a -> (ix -> a -> m b) -> m ()
iforIO_ arr action = imapIO_ action arr
{-# INLINE iforIO_ #-}