{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-redundant-constraints #-}
-- |
-- Operations that change the shape, extent or element order of arrays:
-- transposition, backpermutation, resizing, extraction of sub-arrays,
-- appending/concatenation, splitting, up/down sampling and generic
-- index-based transformations.
module Data.Massiv.Array.Ops.Transform
(
-- ** Transpose
transpose
, transposeInner
, transposeOuter
-- ** Backpermute
, backpermuteM
, backpermute'
, backpermute
-- ** Resize
, resizeM
, resize'
, resize
, flatten
-- ** Extract
, extractM
, extract
, extract'
, extractFromToM
, extractFromTo
, extractFromTo'
-- ** Append/Split
, cons
, unconsM
, snoc
, unsnocM
, appendM
, append
, append'
, concatM
, concat'
, splitAtM
, splitAt
, splitAt'
-- ** Upsample/Downsample
, upsample
, downsample
-- ** Zoom
, zoomWithGrid
-- ** Transform
, transformM
, transform'
, transform2M
, transform2'
-- ** Traverse (deprecated)
, traverse
, traverse2
) where
import Control.Scheduler (traverse_)
import Control.Monad as M (foldM_, unless)
import Data.Bifunctor (bimap)
import Data.Foldable as F (foldl', foldrM, toList)
import qualified Data.List as L (uncons)
import Data.Massiv.Array.Delayed.Pull
import Data.Massiv.Array.Delayed.Push
import Data.Massiv.Array.Mutable
import Data.Massiv.Array.Ops.Construct
import Data.Massiv.Array.Ops.Map
import Data.Massiv.Core.Common
import Data.Massiv.Core.Index.Internal (Sz(SafeSz))
import Prelude as P hiding (concat, splitAt, traverse, mapM_)
-- | Extract a sub-array of size @newSz@ that starts at index @sIx@ of the source array.
--
-- Throws 'SizeSubregionException' (through 'throwM') when the requested region does not
-- pass the bounds checks below.
extractM ::
     (MonadThrow m, Extract r ix e)
  => ix -- ^ Starting index
  -> Sz ix -- ^ Size of the resulting array
  -> Array r ix e -- ^ Source array
  -> m (Array (EltRepr r ix) ix e)
extractM !sIx !newSz !arr
  | regionFits = pure (unsafeExtract sIx newSz arr)
  | otherwise = throwM (SizeSubregionException (size arr) sIx newSz)
  where
    -- One past the last index of the requested region.
    endIx = liftIndex2 (+) sIx (unSz newSz)
    -- Every bound is widened by one before the check, so an index that is equal to
    -- the bound itself is still accepted (e.g. an empty region at the very end).
    widen ix = Sz (liftIndex (+ 1) ix)
    srcBound = widen (unSz (size arr))
    regionFits =
      isSafeIndex srcBound sIx && isSafeIndex (widen endIx) sIx && isSafeIndex srcBound endIx
{-# INLINE extractM #-}
-- | Same as 'extractM', except that a failure is raised with 'throw' from pure code
-- instead of being reported in a 'MonadThrow' context.
extract' ::
     Extract r ix e
  => ix -- ^ Starting index
  -> Sz ix -- ^ Size of the resulting array
  -> Array r ix e -- ^ Source array
  -> Array (EltRepr r ix) ix e
extract' sIx newSz = \arr ->
  case extractM sIx newSz arr of
    Left exc -> throw exc
    Right subArr -> subArr
{-# INLINE extract' #-}
-- | Extract a sub-array of size @newSz@ that starts at index @sIx@. Returns 'Nothing'
-- when the requested region does not fit into the source array.
--
-- This is 'extractM' specialised to 'Maybe'.
extract ::
     Extract r ix e
  => ix -- ^ Starting index
  -> Sz ix -- ^ Size of the resulting array
  -> Array r ix e -- ^ Source array
  -> Maybe (Array (EltRepr r ix) ix e)
extract = extractM
{-# INLINE extract #-}
{-# DEPRECATED extract "In favor of a more general `extractM`" #-}
-- | Like 'extractM', but the region is given by its starting index and the index right
-- after its end: the size passed on is the element-wise difference @eIx - sIx@.
extractFromToM ::
     (MonadThrow m, Extract r ix e)
  => ix -- ^ Starting index
  -> ix -- ^ Index up to which elements are extracted (not included)
  -> Array r ix e -- ^ Source array
  -> m (Array (EltRepr r ix) ix e)
extractFromToM sIx eIx = extractM sIx regionSz
  where
    regionSz = Sz (liftIndex2 (-) eIx sIx)
{-# INLINE extractFromToM #-}
-- | Like 'extract', but the region is given by its starting index and the index right
-- after its end. Returns 'Nothing' when the region does not fit into the source array.
--
-- This is 'extractFromToM' specialised to 'Maybe'.
extractFromTo ::
     Extract r ix e
  => ix -- ^ Starting index
  -> ix -- ^ Index up to which elements are extracted (not included)
  -> Array r ix e -- ^ Source array
  -> Maybe (Array (EltRepr r ix) ix e)
extractFromTo = extractFromToM
{-# INLINE extractFromTo #-}
{-# DEPRECATED extractFromTo "In favor of a more general `extractFromToM`" #-}
-- | Same as 'extractFromToM', except that a failure is raised from pure code: the
-- work is delegated to @extract'@ with the size computed as @eIx - sIx@.
extractFromTo' ::
     Extract r ix e
  => ix -- ^ Starting index
  -> ix -- ^ Index up to which elements are extracted (not included)
  -> Array r ix e -- ^ Source array
  -> Array (EltRepr r ix) ix e
extractFromTo' sIx eIx = extract' sIx regionSz
  where
    regionSz = Sz (liftIndex2 (-) eIx sIx)
{-# INLINE extractFromTo' #-}
-- | Change the size (and possibly the dimensionality) of an array without touching its
-- elements. Returns 'Nothing' unless the total number of elements of the new size is
-- equal to that of the source array.
resize ::
     (Index ix', Load r ix e, Resize r ix)
  => Sz ix' -- ^ New size
  -> Array r ix e -- ^ Source array
  -> Maybe (Array r ix' e)
resize !sz !arr =
  if totalElem sz == totalElem (size arr)
    then Just (unsafeResize sz arr)
    else Nothing
{-# INLINE resize #-}
{-# DEPRECATED resize "In favor of a more general `resizeM`" #-}
-- | Change the size (and possibly the dimensionality) of an array without touching its
-- elements.
--
-- The sizes are first validated with 'guardNumberOfElements' (presumably it throws when
-- the total number of elements differs - confirm against its definition); only then
-- is the array resized.
resizeM ::
     (MonadThrow m, Index ix', Load r ix e, Resize r ix)
  => Sz ix' -- ^ New size
  -> Array r ix e -- ^ Source array
  -> m (Array r ix' e)
resizeM sz arr = do
  guardNumberOfElements (size arr) sz
  pure (unsafeResize sz arr)
{-# INLINE resizeM #-}
-- | Same as 'resizeM', except that a failure is raised with 'throw' from pure code.
resize' :: (Index ix', Load r ix e, Resize r ix) => Sz ix' -> Array r ix e -> Array r ix' e
resize' sz = \arr ->
  case resizeM sz arr of
    Left exc -> throw exc
    Right resized -> resized
{-# INLINE resize' #-}
-- | Turn an array of any dimension into a one-dimensional array with the same total
-- number of elements. No size check is needed, since the new size is computed from
-- the source array itself.
flatten :: (Load r ix e, Resize r ix) => Array r ix e -> Array r Ix1 e
flatten arr = unsafeResize flatSz arr
  where
    flatSz = SafeSz (totalElem (size arr))
{-# INLINE flatten #-}
-- | Transpose a 2-dimensional array. It is 'transposeInner' restricted to 'Ix2'.
--
-- Inlining is postponed until phase 1, so that the rewrite rules below get a chance
-- to fire on @transpose@ calls first.
transpose :: Source r Ix2 e => Array r Ix2 e -> Array D Ix2 e
transpose = transposeInner
{-# INLINE [1] transpose #-}
-- Applying the same transposition twice is rewritten to a plain `delay` of the
-- original array. The rules are only active before phase 1 (`[~1]`), i.e. before
-- the functions they match on are allowed to inline (`INLINE [1]`).
{-# RULES
"transpose . transpose" [~1] forall arr . transpose (transpose arr) = delay arr
"transposeInner . transposeInner" [~1] forall arr . transposeInner (transposeInner arr) = delay arr
"transposeOuter . transposeOuter" [~1] forall arr . transposeOuter (transposeOuter arr) = delay arr
#-}
-- | Swap the two highest numbered dimensions of an array, i.e. the dimension with
-- number @'dimensions' sz@ and the one right below it. The result is a delayed array
-- that looks up elements of the source at the swapped index.
transposeInner ::
     (Index (Lower ix), Source r' ix e)
  => Array r' ix e
  -> Array D ix e
transposeInner !arr = makeArray (getComp arr) swappedSz (unsafeIndex arr . swapTopDims)
  where
    -- Exchange the values stored at dimensions @topDim@ and @topDim - 1@. Failure of
    -- the dimension accessors is treated as impossible here.
    swapTopDims !ix =
      either throwImpossible id $ do
        topVal <- getDimM ix topDim
        belowVal <- getDimM ix (topDim - 1)
        setDimM ix topDim belowVal >>= \ix' -> setDimM ix' (topDim - 1) topVal
    {-# INLINE swapTopDims #-}
    !swappedSz = Sz (swapTopDims (unSz (size arr)))
    !topDim = dimensions swappedSz
{-# INLINE [1] transposeInner #-}
-- | Swap the two lowest numbered dimensions (@1@ and @2@) of an array. The result is a
-- delayed array that looks up elements of the source at the swapped index.
transposeOuter ::
     (Index (Lower ix), Source r' ix e)
  => Array r' ix e
  -> Array D ix e
transposeOuter !arr = makeArray (getComp arr) swappedSz (unsafeIndex arr . swapLowDims)
  where
    -- Exchange the values stored at dimensions 1 and 2. Failure of the dimension
    -- accessors is treated as impossible here.
    swapLowDims !ix =
      either throwImpossible id $ do
        firstVal <- getDimM ix 1
        secondVal <- getDimM ix 2
        setDimM ix 1 secondVal >>= \ix' -> setDimM ix' 2 firstVal
    {-# INLINE swapLowDims #-}
    !swappedSz = Sz (swapLowDims (unSz (size arr)))
{-# INLINE [1] transposeOuter #-}
-- | Rearrange elements of an array into a new array of size @sz@: the element at
-- every index of the result is looked up in the source array, at the index produced
-- by the supplied mapping. Lookups go through 'evaluateM', so the mapping is allowed
-- to be partial with respect to the source array.
backpermuteM ::
     forall r ix e r' ix' m.
     (Mutable r ix e, Source r' ix' e, MonadUnliftIO m, PrimMonad m, MonadThrow m)
  => Sz ix -- ^ Size of the result array
  -> (ix -> ix') -- ^ Mapping from an index of the result to an index of the source
  -> Array r' ix' e -- ^ Source array
  -> m (Array r ix e)
backpermuteM sz ixF !arr = generateArray (getComp arr) sz (\ix -> evaluateM arr (ixF ix))
{-# INLINE backpermuteM #-}
-- | Pure, delayed counterpart of 'backpermuteM': elements of the source are looked up
-- with @evaluate'@ at the index produced by the supplied mapping.
backpermute' ::
     (Source r' ix' e, Index ix)
  => Sz ix -- ^ Size of the result array
  -> (ix -> ix') -- ^ Mapping from an index of the result to an index of the source
  -> Array r' ix' e -- ^ Source array
  -> Array D ix e
backpermute' sz ixF !arr = makeArray (getComp arr) sz (\ix -> evaluate' arr (ixF ix))
{-# INLINE backpermute' #-}
-- | Deprecated synonym for @backpermute'@.
backpermute ::
     (Source r' ix' e, Index ix)
  => Sz ix -- ^ Size of the result array
  -> (ix -> ix') -- ^ Mapping from an index of the result to an index of the source
  -> Array r' ix' e -- ^ Source array
  -> Array D ix e
backpermute sz ixF arr = backpermute' sz ixF arr
{-# INLINE backpermute #-}
{-# DEPRECATED backpermute "In favor of a safe `backpermuteM` or an equivalent `backpermute'`" #-}
-- | Prepend an element to a one-dimensional push array. The size grows by one and the
-- loading action first writes the new element at the starting offset, then loads the
-- original array right after it.
cons :: e -> Array DL Ix1 e -> Array DL Ix1 e
cons e arr =
  arr
    { dlSize = SafeSz (unSz (dlSize arr) + 1)
    , dlLoad =
        \scheduler startAt uWrite -> do
          uWrite startAt e
          dlLoad arr scheduler (startAt + 1) uWrite
    }
{-# INLINE cons #-}
-- | Split a one-dimensional array into its first element and a delayed array with the
-- remaining elements. Throws 'SizeEmptyException' when the array has no elements.
unconsM :: (MonadThrow m, Source r Ix1 e) => Array r Ix1 e -> m (e, Array D Ix1 e)
unconsM arr
  | totalElem sz == 0 = throwM (SizeEmptyException sz)
  | otherwise = pure (unsafeLinearIndex arr 0, rest)
  where
    !sz = size arr
    -- Everything but the first element: indices are shifted by one into the source.
    rest = makeArray (getComp arr) (SafeSz (unSz sz - 1)) (\ !j -> unsafeLinearIndex arr (j + 1))
{-# INLINE unconsM #-}
-- | Append an element to the end of a one-dimensional push array. The size grows by
-- one and the loading action loads the original array first, then writes the new
-- element right after its last position.
snoc :: Array DL Ix1 e -> e -> Array DL Ix1 e
snoc arr e =
  arr
    { dlSize = SafeSz (len + 1)
    , dlLoad =
        \scheduler startAt uWrite -> do
          dlLoad arr scheduler startAt uWrite
          uWrite (startAt + len) e
    }
  where
    !len = unSz (size arr)
{-# INLINE snoc #-}
-- | Split a one-dimensional array into a delayed array with all but the last element
-- and the last element itself. Throws 'SizeEmptyException' when the array has no
-- elements.
unsnocM :: (MonadThrow m, Source r Ix1 e) => Array r Ix1 e -> m (Array D Ix1 e, e)
unsnocM arr
  | totalElem sz == 0 = throwM (SizeEmptyException sz)
  | otherwise = pure (front, unsafeLinearIndex arr lastIx)
  where
    !sz = size arr
    -- Linear index of the last element, which is also the length of the front part.
    !lastIx = unSz sz - 1
    front = makeArray (getComp arr) (SafeSz lastIx) (unsafeLinearIndex arr)
{-# INLINE unsnocM #-}
-- | Append two arrays together along dimension @n@.
--
-- The sizes of both arrays with dimension @n@ pulled out must agree, otherwise
-- 'SizeMismatchException' is thrown. Failures of 'pullOutSzM' and 'insertSzM' (e.g. for
-- a dimension that is not valid for the index type - see their definitions) are
-- propagated as well.
appendM ::
     (MonadThrow m, Source r1 ix e, Source r2 ix e)
  => Dim -- ^ Dimension along which the arrays are appended
  -> Array r1 ix e -- ^ First array
  -> Array r2 ix e -- ^ Second array, placed after the first one along the dimension
  -> m (Array DL ix e)
appendM n !arr1 !arr2 = do
  let !sz1 = size arr1
      !sz2 = size arr2
  (k1, szl1) <- pullOutSzM sz1 n
  (k2, szl2) <- pullOutSzM sz2 n
  unless (szl1 == szl2) (throwM (SizeMismatchException sz1 sz2))
  -- Extent of the first array along @n@: elements of the second one are shifted by it.
  let offset = unSz k1
  newSz <- insertSzM szl1 n (SafeSz (offset + unSz k2))
  pure $
    DLArray
      { dlComp = getComp arr1 <> getComp arr2
      , dlSize = newSz
      , dlDefault = Nothing
      , dlLoad =
          \scheduler startAt dlWrite -> do
            let writeAt ix = dlWrite (startAt + toLinearIndex newSz ix)
                shifted ix = setDim' ix n (getDim' ix n + offset)
            -- Each source array is copied over as a separate job of the scheduler.
            scheduleWork scheduler $
              iterM_ zeroIndex (unSz sz1) (pureIndex 1) (<) $ \ix ->
                writeAt ix (unsafeIndex arr1 ix)
            scheduleWork scheduler $
              iterM_ zeroIndex (unSz sz2) (pureIndex 1) (<) $ \ix ->
                writeAt (shifted ix) (unsafeIndex arr2 ix)
      }
{-# INLINE appendM #-}
-- | Append two arrays together along a dimension. This is 'appendM' specialised to
-- 'Maybe', so any failure results in 'Nothing'.
append ::
     (Source r1 ix e, Source r2 ix e)
  => Dim -- ^ Dimension along which the arrays are appended
  -> Array r1 ix e
  -> Array r2 ix e
  -> Maybe (Array DL ix e)
append dim arr1 arr2 = appendM dim arr1 arr2
{-# INLINE append #-}
{-# DEPRECATED append "In favor of a more general `appendM`" #-}
-- | Same as 'appendM', except that a failure is raised with 'throw' from pure code.
append' ::
     (Source r1 ix e, Source r2 ix e)
  => Dim -- ^ Dimension along which the arrays are appended
  -> Array r1 ix e
  -> Array r2 ix e
  -> Array DL ix e
append' dim arr1 arr2 =
  case appendM dim arr1 arr2 of
    Left exc -> throw exc
    Right appended -> appended
{-# INLINE append' #-}
-- | Same as 'concatM', except that a failure is raised with 'throw' from pure code.
concat' :: (Foldable f, Source r ix e) => Dim -> f (Array r ix e) -> Array DL ix e
concat' n arrs =
  case concatM n arrs of
    Left exc -> throw exc
    Right joined -> joined
{-# INLINE concat' #-}
-- | Concatenate many arrays together along dimension @n@.
--
-- An empty collection results in an 'empty' array. Otherwise the sizes of all arrays
-- with dimension @n@ pulled out must agree with that of the first array, or
-- 'SizeMismatchException' is thrown for the first array that differs. Failures of
-- 'pullOutDimM' and 'insertSzM' are propagated as well.
--
-- The computation strategy of the result is the combination of strategies of /all/
-- supplied arrays, including the first one.
concatM ::
     (MonadThrow m, Foldable f, Source r ix e)
  => Dim -- ^ Dimension along which the arrays are concatenated
  -> f (Array r ix e) -- ^ Arrays, in the order they will appear in the result
  -> m (Array DL ix e)
concatM n !arrsF =
  case L.uncons (F.toList arrsF) of
    Nothing -> pure empty
    Just (a, arrs) -> do
      let sz = unSz (size a)
          szs = P.map (unSz . size) arrs
      (k, szl) <- pullOutDimM sz n
      -- Pull dimension @n@ out of the sizes of all remaining arrays: @ks@ are the
      -- extents along @n@, @szls@ are the leftover lower dimensional sizes.
      (ks, szls) <-
        F.foldrM
          (\ !csz (accKs, accSzls) -> bimap (: accKs) (: accSzls) <$> pullOutDimM csz n)
          ([], [])
          szs
      -- After `dropWhile` the list starts at the first array whose leftover size
      -- differs from the one of the first array, so that mismatch is the one reported.
      traverse_
        (\(sz', _) -> throwM (SizeMismatchException (SafeSz sz) (SafeSz sz')))
        (dropWhile ((== szl) . snd) $ P.zip szs szls)
      let kTotal = SafeSz $ F.foldl' (+) k ks
      newSz <- insertSzM (SafeSz szl) n kTotal
      return $
        DLArray
          { -- The head array has to take part as well: combining only the tail
            -- would lose its strategy (and yield `mempty` for a single array).
            dlComp = mconcat $ P.map getComp (a : arrs)
          , dlSize = newSz
          , dlDefault = Nothing
          , dlLoad =
              \scheduler startAt dlWrite ->
                -- @kAcc@ is the offset along dimension @n@ at which the current
                -- array is placed; it grows by the extent of every array loaded.
                let arrayLoader !kAcc (kCur, arr) = do
                      scheduleWork scheduler $
                        iterM_ zeroIndex (unSz (size arr)) (pureIndex 1) (<) $ \ix ->
                          let i = getDim' ix n
                              ix' = setDim' ix n (i + kAcc)
                           in dlWrite (startAt + toLinearIndex newSz ix') (unsafeIndex arr ix)
                      pure (kAcc + kCur)
                 in M.foldM_ arrayLoader 0 $ (k, a) : P.zip ks arrs
          }
{-# INLINE concatM #-}
-- | Split an array in two at position @i@ along dimension @dim@.
--
-- The first part covers everything from the zero index up to position @i@ along
-- @dim@, the second part covers the rest. Failures of 'setDimM' and of the two
-- extractions are propagated through the 'MonadThrow' context.
splitAtM ::
     (MonadThrow m, Extract r ix e, r' ~ EltRepr r ix)
  => Dim -- ^ Dimension along which to split
  -> Int -- ^ Position along that dimension where the second part starts
  -> Array r ix e -- ^ Source array
  -> m (Array r' ix e, Array r' ix e)
splitAtM dim i arr = do
  let fullIx = unSz (size arr)
  -- End of the first part: the full size with dimension @dim@ cut down to @i@.
  eIx <- setDimM fullIx dim i
  -- Start of the second part: zero everywhere, except for @i@ at dimension @dim@.
  sIx <- setDimM zeroIndex dim i
  (,) <$> extractFromToM zeroIndex eIx arr <*> extractFromToM sIx fullIx arr
{-# INLINE splitAtM #-}
-- | Split an array in two at position @i@ along dimension @dim@. This is 'splitAtM'
-- specialised to 'Maybe', so any failure results in 'Nothing'.
splitAt ::
     (Extract r ix e, r' ~ EltRepr r ix)
  => Dim -- ^ Dimension along which to split
  -> Int -- ^ Position along that dimension where the second part starts
  -> Array r ix e -- ^ Source array
  -> Maybe (Array r' ix e, Array r' ix e)
splitAt dim i arr = splitAtM dim i arr
{-# INLINE splitAt #-}
{-# DEPRECATED splitAt "In favor of a more general `splitAtM`" #-}
-- | Same as 'splitAtM', except that a failure is raised with 'throw' from pure code.
splitAt' ::
     (Extract r ix e, r' ~ EltRepr r ix)
  => Dim -- ^ Dimension along which to split
  -> Int -- ^ Position along that dimension where the second part starts
  -> Array r ix e -- ^ Source array
  -> (Array r' ix e, Array r' ix e)
splitAt' dim i arr =
  case splitAtM dim i arr of
    Left exc -> throw exc
    Right parts -> parts
{-# INLINE splitAt' #-}
-- | Discard elements of the source array according to the stride: the element at
-- index @ix@ of the result is the element of the source at index @ix * stride@
-- (element-wise multiplication). The size of the result is @'strideSize' stride sz@.
downsample :: Source r ix e => Stride ix -> Array r ix e -> Array DL ix e
downsample !stride arr =
  DLArray
    { dlComp = getComp arr
    , dlSize = outSz
    , dlDefault = Nothing
    , dlLoad =
        \scheduler startAt dlWrite ->
          splitLinearlyWithStartAtM_
            scheduler
            startAt
            (totalElem outSz)
            (pure . readStrided)
            dlWrite
    }
  where
    outSz = strideSize stride (size arr)
    step = unStride stride
    -- Element of the source that ends up at linear index @i@ of the result.
    readStrided i = unsafeIndex arr (liftIndex2 (*) step (fromLinearIndex outSz i))
    {-# INLINE readStrided #-}
{-# INLINE downsample #-}
-- | Insert the same element into an array according to the stride: the element at
-- index @ix@ of the source ends up at index @ix * stride@ (element-wise
-- multiplication) of the result, all other cells hold @fillWith@. The size of the
-- result is the size of the source multiplied by the stride.
--
-- The loading action honours the starting offset it is given: both the fill and the
-- writes of actual elements are relative to @startAt@, so the result can safely be
-- loaded at a non-zero offset of a bigger target.
upsample ::
     Load r ix e
  => e -- ^ Element to use for filling the newly added cells
  -> Stride ix -- ^ Fill cells according to this stride
  -> Array r ix e -- ^ Array that will have cells added to
  -> Array DL ix e
upsample !fillWith !safeStride arr =
  DLArray
    { dlComp = getComp arr
    , dlSize = newsz
    , dlDefault = Nothing
    , dlLoad =
        \scheduler startAt dlWrite -> do
          -- With a stride of all ones every cell is overwritten by the source, so
          -- the fill can be skipped. Otherwise fill the whole region of the result,
          -- which spans @totalElem newsz@ cells starting at @startAt@.
          unless (stride == pureIndex 1) $
            loopM_ startAt (< startAt + totalElem newsz) (+ 1) (`dlWrite` fillWith)
          -- @i@ is a linear index into the source. It must be converted to the
          -- linear index in the result first, and only then shifted by the offset.
          loadArrayM scheduler arr (\i -> dlWrite (startAt + adjustLinearStride i))
    }
  where
    -- Linear index in the source ==> linear index in the upsampled result.
    adjustLinearStride = toLinearIndex newsz . timesStride . fromLinearIndex sz
    {-# INLINE adjustLinearStride #-}
    timesStride !ix = liftIndex2 (*) stride ix
    {-# INLINE timesStride #-}
    !stride = unStride safeStride
    !sz = size arr
    !newsz = SafeSz (timesStride $ unSz sz)
{-# INLINE upsample #-}
-- | Create a delayed array of size @sz@, where each element is computed by a function
-- that has access to an element lookup function (@evaluate'@) of the source array.
traverse ::
     (Source r1 ix1 e1, Index ix)
  => Sz ix -- ^ Size of the result array
  -> ((ix1 -> e1) -> ix -> e)
  -- ^ Function that receives the lookup function for the source array and an index
  -- of the result, and produces the element at that index
  -> Array r1 ix1 e1 -- ^ Source array
  -> Array D ix e
traverse sz f arr1 = makeArray (getComp arr1) sz (f lookupElt)
  where
    lookupElt = evaluate' arr1
{-# INLINE traverse #-}
{-# DEPRECATED traverse "In favor of more general `transform'`" #-}
-- | Same as 'traverse', but with two source arrays. Computation strategies of both
-- arrays are combined.
traverse2 ::
     (Source r1 ix1 e1, Source r2 ix2 e2, Index ix)
  => Sz ix -- ^ Size of the result array
  -> ((ix1 -> e1) -> (ix2 -> e2) -> ix -> e)
  -- ^ Function that receives lookup functions for both source arrays and an index
  -- of the result, and produces the element at that index
  -> Array r1 ix1 e1 -- ^ First source array
  -> Array r2 ix2 e2 -- ^ Second source array
  -> Array D ix e
traverse2 sz f arr1 arr2 = makeArray comp sz (f lookup1 lookup2)
  where
    comp = getComp arr1 <> getComp arr2
    lookup1 = evaluate' arr1
    lookup2 = evaluate' arr2
{-# INLINE traverse2 #-}
{-# DEPRECATED traverse2 "In favor of more general `transform2'`" #-}
-- | General array transformation. The first function receives the size of the source
-- array and decides on the size of the result, together with some auxiliary value.
-- The second function gets that value, a safe element lookup function for the source
-- array ('evaluateM') and an index of the result, and produces the element for it.
transformM ::
     forall r ix e r' ix' e' a m.
     (Mutable r ix e, Source r' ix' e', MonadUnliftIO m, PrimMonad m, MonadThrow m)
  => (Sz ix' -> m (Sz ix, a)) -- ^ Compute the new size and an auxiliary value
  -> (a -> (ix' -> m e') -> ix -> m e) -- ^ Compute an element of the result
  -> Array r' ix' e' -- ^ Source array
  -> m (Array r ix e)
transformM getSzM getM arr =
  getSzM (size arr) >>= \(sz, a) -> generateArray (getComp arr) sz (getM a (evaluateM arr))
{-# INLINE transformM #-}
-- | Pure, delayed counterpart of 'transformM': element lookup in the source array is
-- done with @evaluate'@.
transform' ::
     (Source r' ix' e', Index ix)
  => (Sz ix' -> (Sz ix, a)) -- ^ Compute the new size and an auxiliary value
  -> (a -> (ix' -> e') -> ix -> e) -- ^ Compute an element of the result
  -> Array r' ix' e' -- ^ Source array
  -> Array D ix e
transform' getSz get arr =
  let (newSz, extra) = getSz (size arr)
   in makeArray (getComp arr) newSz (get extra (evaluate' arr))
{-# INLINE transform' #-}
-- | Same as 'transformM', but with two source arrays. Computation strategies of both
-- arrays are combined.
transform2M ::
     (Mutable r ix e, Source r1 ix1 e1, Source r2 ix2 e2, MonadUnliftIO m, PrimMonad m, MonadThrow m)
  => (Sz ix1 -> Sz ix2 -> m (Sz ix, a)) -- ^ Compute the new size and an auxiliary value
  -> (a -> (ix1 -> m e1) -> (ix2 -> m e2) -> ix -> m e) -- ^ Compute an element of the result
  -> Array r1 ix1 e1 -- ^ First source array
  -> Array r2 ix2 e2 -- ^ Second source array
  -> m (Array r ix e)
transform2M getSzM getM arr1 arr2 =
  getSzM (size arr1) (size arr2) >>= \(sz, a) ->
    generateArray (getComp arr1 <> getComp arr2) sz (getM a (evaluateM arr1) (evaluateM arr2))
{-# INLINE transform2M #-}
-- | Pure, delayed counterpart of 'transform2M': element lookup in both source arrays
-- is done with @evaluate'@. Computation strategies of both arrays are combined.
transform2' ::
     (Source r1 ix1 e1, Source r2 ix2 e2, Index ix)
  => (Sz ix1 -> Sz ix2 -> (Sz ix, a)) -- ^ Compute the new size and an auxiliary value
  -> (a -> (ix1 -> e1) -> (ix2 -> e2) -> ix -> e) -- ^ Compute an element of the result
  -> Array r1 ix1 e1 -- ^ First source array
  -> Array r2 ix2 e2 -- ^ Second source array
  -> Array D ix e
transform2' getSz get arr1 arr2 =
  let (newSz, extra) = getSz (size arr1) (size arr2)
      comp = getComp arr1 <> getComp arr2
   in makeArray comp newSz (get extra (evaluate' arr1) (evaluate' arr2))
{-# INLINE transform2' #-}
-- | Replicate each element of the source array into a block of cells and separate
-- the blocks with a grid.
--
-- With @k = zoomFactor + 1@ (element-wise), the result has the size @k * sz + 1@ and
-- the element at index @ix@ of the source is written to every index of the range
-- from @k * ix + 1@ to @k * ix + k@ (passed as the bounds to 'range'). Cells that
-- are never written to are left to the default value of the array, which is set to
-- @gridVal@.
--
-- The loading action honours the starting offset it is given, so the result can
-- safely be loaded at a non-zero offset of a bigger target.
zoomWithGrid ::
     Source r ix e
  => e -- ^ Value to use for the grid
  -> Stride ix -- ^ Scaling factor
  -> Array r ix e -- ^ Source array
  -> Array DL ix e
zoomWithGrid gridVal (Stride zoomFactor) arr =
  unsafeMakeLoadArray Seq newSz (Just gridVal) $ \scheduler startAt writeElement ->
    iforSchedulerM_ scheduler arr $ \ !ix !e -> do
      let !kix = liftIndex2 (*) ix kx
      -- Linear indices are relative to this array, hence they must be shifted by
      -- the offset at which the loading was requested to start.
      mapM_ (\ !ix' -> writeElement (startAt + toLinearIndex newSz ix') e) $
        range Seq (liftIndex (+1) kix) (liftIndex2 (+) kix kx)
  where
    -- Side of a block together with one grid line.
    !kx = liftIndex (+1) zoomFactor
    !lastNewIx = liftIndex2 (*) kx $ unSz (size arr)
    -- One extra cell per dimension for the closing grid line.
    !newSz = Sz (liftIndex (+1) lastNewIx)
{-# INLINE zoomWithGrid #-}