{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Data.Massiv.Array.Delayed.Push
( DL(..)
, Array(..)
, toLoadArray
, makeLoadArrayS
, makeLoadArray
, unsafeMakeLoadArray
, fromStrideLoad
) where
import Data.Massiv.Core.Common
import Data.Massiv.Core.Index.Internal (Sz(SafeSz))
import qualified Data.Semigroup as Semigroup
import Prelude hiding (map, zipWith)
import Control.Applicative
#include "massiv.h"
-- | Representation tag for delayed push ("load") arrays.
data DL = DL deriving Show

-- | A delayed push array. Instead of storing elements it stores a loading
-- action ('dlLoad') that writes elements through a callback it is given.
data instance Array DL ix e = DLArray
  { dlComp :: !Comp
    -- ^ Computation strategy, reported through 'getComp'.
  , dlSize :: !(Sz ix)
    -- ^ Size of the array, reported through 'size'.
  , dlDefault :: !(Maybe e)
    -- ^ Optional default element, reported through 'defaultElement'.
    -- 'makeLoadArrayS' sets it because its writer need not touch every cell.
  , dlLoad :: forall m . Monad m
           => Scheduler m ()
           -> Int -- linear index at which this array's first element is written
           -> (Int -> e -> m ()) -- write one element at a linear index
           -> m ()
    -- ^ Loading action; every write is offset by the supplied start index,
    -- which is what lets 1-D arrays be concatenated with '<>'.
  }
-- | The element representation of a 'DL' array is 'DL' itself.
type instance EltRepr DL ix = DL
-- | Construction of push arrays. 'makeArrayLinear' produces an array whose
-- loader writes every element (so it needs no default), splitting the linear
-- index range across the scheduler.
instance Index ix => Construct DL ix e where
  setComp comp arr = arr { dlComp = comp }
  {-# INLINE setComp #-}
  makeArrayLinear comp sz f =
    DLArray
      { dlComp = comp
      , dlSize = sz
      , dlDefault = Nothing
      , dlLoad = \scheduler offset write ->
          splitLinearlyWithStartAtM_ scheduler offset (totalElem sz) (\i -> pure (f i)) write
      }
  {-# INLINE makeArrayLinear #-}
-- | Reshaping a push array only replaces its recorded size; the loader,
-- computation strategy and default element are carried over unchanged.
-- No check is made that the new size has the same number of elements.
instance Index ix => Resize DL ix where
  unsafeResize !newSize (DLArray comp _ mDefault loader) =
    DLArray comp newSize mDefault loader
  {-# INLINE unsafeResize #-}
-- | Concatenate two 1-dimensional push arrays: the elements of the second
-- array are written right after those of the first (its loader is started at
-- an offset equal to the first array's length).
--
-- Only one 'dlDefault' can be recorded for the result, and @def1 <|> def2@
-- keeps the first one when both arrays carry a default. In that case the
-- second array's region is explicitly filled with its /own/ default before
-- its loader runs. Without that fill, cells its loader skips would silently
-- inherit the first array's default.
instance Semigroup (Array DL Ix1 e) where
  (<>) (DLArray c1 sz1 def1 load1) (DLArray c2 sz2 def2 load2) =
    DLArray
      { dlComp = c1 <> c2
      , dlSize = SafeSz (k1 + k2)
      , dlDefault = def1 <|> def2
      , dlLoad = load
      }
    where
      !k1 = unSz sz1
      !k2 = unSz sz2
      load :: Monad m => Scheduler m () -> Int -> (Int -> e -> m ()) -> m ()
      load scheduler startAt dlWrite = do
        let !start2 = startAt + k1
        load1 scheduler startAt dlWrite
        case (def1, def2) of
          -- The whole-array default is def1 here, which is wrong for the
          -- second region; restore def2 over exactly that region.
          (Just _, Just d2) -> mapM_ (\i -> dlWrite i d2) [start2 .. start2 + k2 - 1]
          -- Otherwise the recorded default (if any) is already the right one
          -- for every region that relies on it.
          _ -> pure ()
        load2 scheduler start2 dlWrite
      {-# INLINE load #-}
  {-# INLINE (<>) #-}
-- | The empty push array: zero elements, sequential, and an element function
-- that must never be called (it throws 'Uninitialized' if it is).
instance Monoid (Array DL Ix1 e) where
  mempty = makeArray Seq zeroSz $ \_ -> throwImpossible Uninitialized
  {-# INLINE mempty #-}
  mappend x y = x Semigroup.<> y
  {-# INLINE mappend #-}
-- | Build a sequential push array from a writer action.
--
-- The writer receives a bounds-checked write function: writing at an index
-- outside of @sz@ does nothing and yields 'False'; an in-bounds write stores the
-- element and yields 'True'. Both the index and the element are forced before
-- the check. The supplied default is recorded as the array's
-- 'defaultElement', since the writer is not required to touch every cell.
makeLoadArrayS ::
     Index ix
  => Sz ix
  -> e
  -> (forall m. Monad m => (ix -> e -> m Bool) -> m ())
  -> Array DL ix e
makeLoadArrayS sz defVal writer =
  DLArray
    { dlComp = Seq
    , dlSize = sz
    , dlDefault = Just defVal
    , dlLoad = \_ !offset write ->
        let writeIfInBounds !ix !e
              | not (isSafeIndex sz ix) = pure False
              | otherwise = True <$ write (offset + toLinearIndex sz ix) e
            {-# INLINE writeIfInBounds #-}
         in writer writeIfInBounds
    }
{-# INLINE makeLoadArrayS #-}
-- | Build a push array from a raw loading action, with no default element.
--
-- Deprecated: equivalent to 'unsafeMakeLoadArray' with 'Nothing' as the
-- default; prefer that, or the bounds-checked 'makeLoadArrayS'.
makeLoadArray ::
     Comp
  -> Sz ix
  -> (forall m. Monad m => Scheduler m () -> Int -> (Int -> e -> m ()) -> m ())
  -> Array DL ix e
makeLoadArray comp sz loader =
  DLArray {dlComp = comp, dlSize = sz, dlDefault = Nothing, dlLoad = loader}
{-# INLINE makeLoadArray #-}
{-# DEPRECATED makeLoadArray "In favor of equivalent `unsafeMakeLoadArray` and safe `makeLoadArrayS`" #-}
-- | Build a push array directly from its parts: computation strategy, size,
-- optional default element and loading action.
--
-- Nothing is validated: the loader receives a start offset and a raw linear
-- write function and may write anywhere. When the default is 'Nothing' the
-- loader is presumably expected to write every element itself — this is not
-- checked, hence "unsafe".
unsafeMakeLoadArray ::
     Comp
  -> Sz ix
  -> Maybe e
  -> (forall m. Monad m => Scheduler m () -> Int -> (Int -> e -> m ()) -> m ())
  -> Array DL ix e
unsafeMakeLoadArray comp sz mDefault loader =
  DLArray {dlComp = comp, dlSize = sz, dlDefault = mDefault, dlLoad = loader}
{-# INLINE unsafeMakeLoadArray #-}
-- | Convert any 'Load'able array into a delayed push ('DL') array.
--
-- The result keeps the source's computation strategy, size and
-- 'defaultElement'. Every element the source loads is written at its linear
-- index shifted by the start offset. Keeping the default matters for sources
-- whose loaders do not write every cell (e.g. arrays built with
-- 'makeLoadArrayS'). Dropping it would leave those cells without the value the
-- source promised.
toLoadArray :: Load r ix e => Array r ix e -> Array DL ix e
toLoadArray arr =
  DLArray (getComp arr) (size arr) (defaultElement arr) $ \scheduler startAt dlWrite ->
    loadArrayM scheduler arr (\ !i -> dlWrite (i + startAt))
{-# INLINE toLoadArray #-}
-- | Turn a strided load of an array into a push array.
--
-- The result's size is @strideSize stride (size arr)@, and its loader delegates
-- to 'loadArrayWithStrideM', shifting every linear write by the start offset.
-- No default element is recorded.
fromStrideLoad
  :: StrideLoad r ix e => Stride ix -> Array r ix e -> Array DL ix e
fromStrideLoad stride arr =
  DLArray
    { dlComp = getComp arr
    , dlSize = stridedSz
    , dlDefault = Nothing
    , dlLoad = \scheduler offset write ->
        loadArrayWithStrideM scheduler stride stridedSz arr (\ !i -> write (offset + i))
    }
  where
    stridedSz = strideSize stride (size arr)
{-# INLINE fromStrideLoad #-}
-- | A push array is loaded by running its stored loader from linear index 0;
-- the other methods simply read the corresponding record fields.
instance Index ix => Load DL ix e where
  getComp = dlComp
  {-# INLINE getComp #-}
  size = dlSize
  {-# INLINE size #-}
  defaultElement = dlDefault
  {-# INLINE defaultElement #-}
  loadArrayM scheduler arr = dlLoad arr scheduler 0
  {-# INLINE loadArrayM #-}
-- | Mapping over a push array applies the function to each element as it is
-- written, and to the default element (if any). Size and computation strategy
-- are unchanged.
instance Functor (Array DL ix) where
  fmap f (DLArray comp sz mDefault load) =
    DLArray comp sz (fmap f mDefault) $ \scheduler offset write ->
      load scheduler offset (\ !i -> write i . f)
  {-# INLINE fmap #-}