{-# LANGUAGE TypeFamilies #-}
module Data.Array.Comfort.Storable.Mutable.Private where

import qualified Data.Array.Comfort.Shape as Shape

import qualified Foreign.Marshal.Array.Guarded as Alloc
import Foreign.Marshal.Array (copyArray, pokeArray, peekArray)
import Foreign.Storable (Storable, pokeElemOff, peekElemOff)
import Foreign.Ptr (Ptr)

import Control.Monad.Primitive (PrimMonad, unsafeIOToPrim)
import Control.Monad.ST (ST)
import Control.Monad (liftM)
import Control.Applicative ((<$>))

import Data.Either.HT (maybeRight)
import Data.Tuple.HT (mapFst)

import qualified Prelude as P
import Prelude hiding (read, show)


-- | A mutable array with shape @sh@ and 'Storable' elements of type @a@.
--
-- The parameter @m@ does not occur in any field. It only records the monad
-- (for instance @'ST' s@ or 'IO') in which the operations of this module act
-- on the array.
data Array (m :: * -> *) sh a = Array
   { shape :: sh
     -- ^ Shape of the array; the constructors in this module allocate
     --   @'Shape.size' shape@ elements.
   , buffer :: Alloc.MutablePtr a
     -- ^ Guarded pointer to the element storage.
   }

-- | Mutable array living in the @'ST' s@ monad.
type STArray s = Array (ST s)

-- | Mutable array living in the 'IO' monad.
type IOArray = Array IO


-- | Allocate a fresh array of the same shape and copy every element of the
-- source array into it.
copy ::
   (PrimMonad m, Shape.C sh, Storable a) =>
   Array m sh a -> m (Array m sh a)
copy (Array sh src) = unsafeCreateWithSize sh fill
   where
      -- 'count' is the element count of the new buffer, i.e. @Shape.size sh@,
      -- which is also the size of the source buffer.
      fill count dst =
         Alloc.withMutablePtr src (\srcPtr -> copyArray dst srcPtr count)


-- | Allocate an 'IOArray' of the given shape and initialise it with an action
-- that receives only the element pointer.
create ::
   (Shape.C sh, Storable a) =>
   sh -> (Ptr a -> IO ()) -> IO (IOArray sh a)
create sh f = createWithSize sh (\_size -> f)

-- | Like 'create', but the initialiser also receives the number of elements
-- (@Shape.size sh@) that were allocated.
createWithSize ::
   (Shape.C sh, Storable a) =>
   sh -> (Int -> Ptr a -> IO ()) -> IO (IOArray sh a)
createWithSize sh f = fmap fst (createWithSizeAndResult sh f)

-- | Allocate @Shape.size sh@ elements with 'Alloc.new' and run the initialiser
-- on the element count and the raw pointer.
--
-- Returns the new array together with whatever the initialiser produced.
createWithSizeAndResult ::
   (Shape.C sh, Storable a) =>
   sh -> (Int -> Ptr a -> IO b) -> IO (IOArray sh a, b)
createWithSizeAndResult sh f =
   Alloc.new n >>= \mptr ->
   Alloc.withMutablePtr mptr (f n) >>= \result ->
   return (Array sh mptr, result)
   where
      n = Shape.size sh


-- | Allocate and initialise an array in any 'PrimMonad'.
--
-- The initialiser runs in 'IO' and is lifted with 'unsafeIOToPrim' (see
-- 'unsafeCreateWithSizeAndResult'), hence the @unsafe@ prefix.
unsafeCreate ::
   (PrimMonad m, Shape.C sh, Storable a) =>
   sh -> (Ptr a -> IO ()) -> m (Array m sh a)
unsafeCreate sh f = unsafeCreateWithSize sh (\_ -> f)

-- | Like 'unsafeCreate', but the initialiser also receives the element count.
unsafeCreateWithSize ::
   (PrimMonad m, Shape.C sh, Storable a) =>
   sh -> (Int -> Ptr a -> IO ()) -> m (Array m sh a)
unsafeCreateWithSize sh =
   fmap fst . unsafeCreateWithSizeAndResult sh

-- | Run 'createWithSizeAndResult' in 'IO' and lift the result into an
-- arbitrary 'PrimMonad' via 'unsafeIOToPrim'.
--
-- The array's monad tag is retargeted with 'unsafeArrayIOToPrim'.
unsafeCreateWithSizeAndResult ::
   (PrimMonad m, Shape.C sh, Storable a) =>
   sh -> (Int -> Ptr a -> IO b) -> m (Array m sh a, b)
unsafeCreateWithSizeAndResult sh f =
   unsafeIOToPrim $ do
      (ioArr, result) <- createWithSizeAndResult sh f
      return (unsafeArrayIOToPrim ioArr, result)

-- | Re-tag an 'IOArray' as an array of another 'PrimMonad'.
--
-- Shape and buffer are reused as they are; only the phantom monad parameter
-- changes.
unsafeArrayIOToPrim :: (PrimMonad m) => IOArray sh a -> Array m sh a
unsafeArrayIOToPrim Array {shape = sh, buffer = ptr} =
   Array {shape = sh, buffer = ptr}


-- | Render the current contents as
-- @StorableArray.fromList \<shape\> \<element list\>@.
--
-- The shape is shown at precedence 11, so compound shapes get parenthesised.
show ::
   (PrimMonad m, Shape.C sh, Show sh, Storable a, Show a) =>
   Array m sh a -> m String
show arr = fmap render (toList arr)
   where
      render xs =
         showString "StorableArray.fromList " .
         showsPrec 11 (shape arr) .
         showChar ' ' $
         P.show xs

-- | Run an 'IO' action on the raw pointer behind a guarded buffer and lift
-- the action into @m@ with 'unsafeIOToPrim'.
withArrayPtr :: (PrimMonad m) => Alloc.MutablePtr a -> (Ptr a -> IO b) -> m b
withArrayPtr fptr act = unsafeIOToPrim (Alloc.withMutablePtr fptr act)

-- | Run an 'IO' action on the raw element pointer of an array.
withPtr :: (PrimMonad m) => Array m sh a -> (Ptr a -> IO b) -> m b
withPtr Array {buffer = fptr} = withArrayPtr fptr

-- | Read the element at the given index.
--
-- The offset is computed with 'Shape.uncheckedOffset', so the index is
-- presumably not validated here. Callers should make sure it is in range;
-- 'readEither' is the checked variant.
read ::
   (PrimMonad m, Shape.Indexed sh, Storable a) =>
   Array m sh a -> Shape.Index sh -> m a
read (Array sh fptr) ix =
   withArrayPtr fptr (\ptr -> peekElemOff ptr offset)
   where
      offset = Shape.uncheckedOffset sh ix

-- | Checked read that discards the error message.
--
-- Returns 'Nothing' exactly when 'readEither' returns 'Left'.
readMaybe ::
   (PrimMonad m, Shape.Indexed sh, Storable a) =>
   Array m sh a -> Shape.Index sh -> Maybe (m a)
readMaybe arr ix = maybeRight (readEither arr ix)

-- | Checked read.
--
-- The offset is computed with the checked variant of 'Shape.unifiedOffset'.
-- If that check fails, its message is returned in 'Left'. Otherwise the
-- result is an action that reads the element at that offset.
readEither ::
   (PrimMonad m, Shape.Indexed sh, Storable a) =>
   Array m sh a -> Shape.Index sh -> Either String (m a)
readEither (Array sh fptr) ix =
   case Shape.getChecked (Shape.unifiedOffset sh ix) of
      Left msg -> Left msg
      Right k -> Right (withArrayPtr fptr (\ptr -> peekElemOff ptr k))

-- | Overwrite the element at the given index.
--
-- Like 'read', this uses 'Shape.uncheckedOffset'. The index is presumably
-- not validated, so callers must keep it in range.
write ::
   (PrimMonad m, Shape.Indexed sh, Storable a) =>
   Array m sh a -> Shape.Index sh -> a -> m ()
write (Array sh fptr) ix a =
   withArrayPtr fptr (\ptr -> pokeElemOff ptr offset a)
   where
      offset = Shape.uncheckedOffset sh ix

-- | Replace the element at the given index by @f@ applied to its current
-- value.
--
-- The offset comes from 'Shape.uncheckedOffset' (presumably no range check).
update ::
   (PrimMonad m, Shape.Indexed sh, Storable a) =>
   Array m sh a -> Shape.Index sh -> (a -> a) -> m ()
update (Array sh fptr) ix f =
   withArrayPtr fptr $ \ptr -> do
      old <- peekElemOff ptr k
      pokeElemOff ptr k (f old)
   where
      k = Shape.uncheckedOffset sh ix

-- | Allocate an array of the given shape with every element set to @x@.
new :: (PrimMonad m, Shape.C sh, Storable a) => sh -> a -> m (Array m sh a)
new sh x = unsafeCreateWithSize sh fill
   where
      fill size ptr = pokeArray ptr (replicate size x)

-- | Read all @Shape.size@ elements of the array into a list.
toList :: (PrimMonad m, Shape.C sh, Storable a) => Array m sh a -> m [a]
toList (Array sh fptr) =
   withArrayPtr fptr (\ptr -> peekArray (Shape.size sh) ptr)

-- | Build an array of shape @sh@ from a list.
--
-- Only the first @Shape.size sh@ list elements are written. The buffer holds
-- exactly that many elements, and poking the whole list would write past the
-- allocation for lists that are too long.
--
-- If the list is shorter than the shape, the trailing elements are not
-- written by this function and keep whatever the allocation contained.
fromList ::
   (PrimMonad m, Shape.C sh, Storable a) => sh -> [a] -> m (Array m sh a)
fromList sh xs =
   unsafeCreateWithSize sh $ \size ptr -> pokeArray ptr (take size xs)

-- | Build a zero-based one-dimensional array whose size is the list length,
-- holding the list's elements in order.
vectorFromList ::
   (PrimMonad m, Storable a) => [a] -> m (Array m (Shape.ZeroBased Int) a)
vectorFromList xs =
   unsafeCreate (Shape.ZeroBased n) (\ptr -> pokeArray ptr xs)
   where
      n = length xs