{-# LANGUAGE TypeFamilies #-}
module Data.Array.Comfort.Storable.Mutable.Private where

import qualified Data.Array.Comfort.Shape as Shape

import qualified Foreign.Marshal.Array.Guarded as Alloc
import Foreign.Marshal.Array (copyArray, pokeArray, peekArray)
import Foreign.Storable (Storable, pokeElemOff, peekElemOff)
import Foreign.Ptr (Ptr)

import Control.Monad.Primitive (PrimMonad, unsafeIOToPrim)
import Control.Monad.ST (ST)
import Control.Monad (liftM)
import Control.Applicative ((<$>))

import Data.Tuple.HT (mapFst)

import qualified Prelude as P
import Prelude hiding (read, show)


-- | A mutable storable array tied to the monad @m@ in which it is accessed.
--
-- @m@ is a phantom index: no field mentions it. The same representation is
-- used for 'IO' arrays and 'ST' arrays; see 'unsafeArrayIOToPrim' for how it
-- is re-indexed.
data Array (m :: * -> *) sh a = Array
   { shape :: sh
     -- ^ Shape descriptor. The creation functions in this module allocate
     -- room for @Shape.size shape@ elements.
   , buffer :: Alloc.MutablePtr a
     -- ^ Guarded, mutable pointer to the element storage.
   }

-- | A mutable array usable inside the strict 'ST' monad with state thread @s@.
type STArray s = Array (ST s)
-- | A mutable array usable inside 'IO'. The synonym is left partially applied
-- on purpose, so @IOArray@ can be used unsaturated (e.g. as @IOArray sh@).
type IOArray = Array IO


-- | Allocate a fresh array of the same shape and copy the source's elements
-- into it.
--
-- The element count is the @Shape.size@ of the shape, as supplied by
-- 'unsafeCreateWithSize'. The copy is done with 'copyArray', a raw memory
-- copy. The new array does not alias the source buffer.
copy ::
   (PrimMonad m, Shape.C sh, Storable a) =>
   Array m sh a -> m (Array m sh a)
copy (Array sh sourceBuf) =
   unsafeCreateWithSize sh fillFromSource
   where
      -- copy exactly 'count' elements from the source into the new buffer
      fillFromSource count target =
         Alloc.withMutablePtr sourceBuf $ \source ->
            copyArray target source count


-- | Allocate an 'IOArray' of the given shape and initialise it with @fill@,
-- which receives a pointer to the new buffer.
--
-- This is 'createWithSize' with the element count ignored.
create ::
   (Shape.C sh, Storable a) =>
   sh -> (Ptr a -> IO ()) -> IO (IOArray sh a)
create sh fill = createWithSize sh (\_count -> fill)

-- | Allocate an 'IOArray' of the given shape and initialise it with @fill@.
--
-- @fill@ receives the element count (@Shape.size sh@) and a pointer to the
-- buffer. It is 'createWithSizeAndResult' with the unit result dropped.
createWithSize ::
   (Shape.C sh, Storable a) =>
   sh -> (Int -> Ptr a -> IO ()) -> IO (IOArray sh a)
createWithSize sh fill = fmap fst (createWithSizeAndResult sh fill)

-- | Allocate a buffer for @Shape.size sh@ elements with 'Alloc.new', run
-- @fill@ on it, and return the new array together with @fill@'s result.
--
-- @fill@ receives the element count and a pointer to the fresh buffer,
-- obtained through 'Alloc.withMutablePtr'.
-- NOTE(review): the pointer presumably must not escape @fill@, as with other
-- with-style pointer APIs. Confirm against the guarded-allocation docs.
createWithSizeAndResult ::
   (Shape.C sh, Storable a) =>
   sh -> (Int -> Ptr a -> IO b) -> IO (IOArray sh a, b)
createWithSizeAndResult sh fill = do
   buf <- Alloc.new count
   result <- Alloc.withMutablePtr buf (fill count)
   return (Array sh buf, result)
   where
      count :: Int
      count = Shape.size sh


-- | Like 'create', but returns an array in any 'PrimMonad'.
--
-- This is 'unsafeCreateWithSize' with the element count ignored. It carries
-- the same obligation: @fill@ should only initialise the fresh buffer.
unsafeCreate ::
   (PrimMonad m, Shape.C sh, Storable a) =>
   sh -> (Ptr a -> IO ()) -> m (Array m sh a)
unsafeCreate sh fill = unsafeCreateWithSize sh (\_count -> fill)

-- | Like 'createWithSize', but returns an array in any 'PrimMonad'.
--
-- This is 'unsafeCreateWithSizeAndResult' with the unit result dropped.
unsafeCreateWithSize ::
   (PrimMonad m, Shape.C sh, Storable a) =>
   sh -> (Int -> Ptr a -> IO ()) -> m (Array m sh a)
unsafeCreateWithSize sh fill =
   fmap fst (unsafeCreateWithSizeAndResult sh fill)

-- | Run 'createWithSizeAndResult' in 'IO' and lift it into any 'PrimMonad'
-- with 'unsafeIOToPrim'. The resulting array is re-indexed to @m@.
--
-- NOTE(review): this is sound only while @fill@ has no externally visible
-- effects besides writing the freshly allocated buffer. That is the usual
-- 'unsafeIOToPrim' obligation; confirm callers respect it.
unsafeCreateWithSizeAndResult ::
   (PrimMonad m, Shape.C sh, Storable a) =>
   sh -> (Int -> Ptr a -> IO b) -> m (Array m sh a, b)
unsafeCreateWithSizeAndResult sh fill =
   unsafeIOToPrim (retag <$> createWithSizeAndResult sh fill)
   where
      -- lazy pattern, matching the laziness of Data.Tuple.HT.mapFst
      retag ~(arr, result) = (unsafeArrayIOToPrim arr, result)

-- | Re-index an 'IOArray' to an arbitrary 'PrimMonad'.
--
-- The shape and the buffer are reused as-is, so the result aliases the same
-- storage as the input. Nothing is copied.
unsafeArrayIOToPrim :: (PrimMonad m) => IOArray sh a -> Array m sh a
unsafeArrayIOToPrim Array {shape = sh, buffer = buf} =
   Array {shape = sh, buffer = buf}


-- | Render the current contents as
-- @StorableArray.fromList \<shape\> \<elements\>@.
--
-- All elements are read with 'toList'. The shape is shown at precedence 11,
-- so a compound shape gets parenthesised.
show ::
   (PrimMonad m, Shape.C sh, Show sh, Storable a, Show a) =>
   Array m sh a -> m String
show arr =
   liftM render (toList arr)
   where
      render elems =
         "StorableArray.fromList " ++
         showsPrec 11 (shape arr) (' ' : P.show elems)

-- | Run an 'IO' action on the raw pointer behind a guarded mutable buffer,
-- lifting the whole action into @m@ with 'unsafeIOToPrim'.
withArrayPtr :: (PrimMonad m) => Alloc.MutablePtr a -> (Ptr a -> IO b) -> m b
withArrayPtr buf act = unsafeIOToPrim (Alloc.withMutablePtr buf act)

-- | Run an 'IO' action on the raw element pointer of an array. The shape is
-- ignored.
withPtr :: (PrimMonad m) => Array m sh a -> (Ptr a -> IO b) -> m b
withPtr (Array _ buf) act = withArrayPtr buf act

-- | Read the element at index @ix@.
--
-- The element offset comes from 'Shape.uncheckedOffset'. This function
-- performs no bounds check of its own, so the caller must pass an in-range
-- index.
read ::
   (PrimMonad m, Shape.Indexed sh, Storable a) =>
   Array m sh a -> Shape.Index sh -> m a
read (Array sh buf) ix =
   withArrayPtr buf (\ptr -> peekElemOff ptr offset)
   where
      offset = Shape.uncheckedOffset sh ix

-- | Overwrite the element at index @ix@ with @x@.
--
-- The offset comes from 'Shape.uncheckedOffset'. No bounds check is made
-- here, so an out-of-range index writes outside the buffer.
write ::
   (PrimMonad m, Shape.Indexed sh, Storable a) =>
   Array m sh a -> Shape.Index sh -> a -> m ()
write (Array sh buf) ix x =
   withArrayPtr buf (\ptr -> pokeElemOff ptr offset x)
   where
      offset = Shape.uncheckedOffset sh ix

-- | Replace the element at index @ix@ with @f@ applied to its current value.
--
-- The read and the write happen inside a single pointer scope and use the
-- same offset from 'Shape.uncheckedOffset'. No bounds check is performed
-- here.
update ::
   (PrimMonad m, Shape.Indexed sh, Storable a) =>
   Array m sh a -> Shape.Index sh -> (a -> a) -> m ()
update (Array sh buf) ix f =
   withArrayPtr buf $ \ptr -> do
      old <- peekElemOff ptr offset
      pokeElemOff ptr offset (f old)
   where
      offset :: Int
      offset = Shape.uncheckedOffset sh ix

-- | Allocate an array of the given shape with every element set to @x@.
--
-- Offsets @0@ through @Shape.size sh - 1@ are written, one element each.
new :: (PrimMonad m, Shape.C sh, Storable a) => sh -> a -> m (Array m sh a)
new sh x =
   unsafeCreateWithSize sh $ \count ptr ->
      mapM_ (\i -> pokeElemOff ptr i x) [0 .. count - 1]

-- | Read all @Shape.size@ elements into a list, in buffer order (offset 0
-- first).
toList :: (PrimMonad m, Shape.C sh, Storable a) => Array m sh a -> m [a]
toList (Array sh buf) =
   withArrayPtr buf readAll
   where
      readAll ptr = peekArray (Shape.size sh) ptr

-- | Allocate an array of shape @sh@ and fill it from @xs@.
--
-- At most @Shape.size sh@ elements are written. Extra list elements are
-- ignored, which also makes an infinite @xs@ safe. The previous code passed
-- the whole list to 'pokeArray', which wrote past the end of the buffer
-- whenever @xs@ was longer than the shape.
--
-- If @xs@ is shorter than the shape, the remaining slots keep whatever
-- 'Alloc.new' left there. NOTE(review): presumably uninitialised memory;
-- confirm whether callers rely on passing short lists.
fromList ::
   (PrimMonad m, Shape.C sh, Storable a) => sh -> [a] -> m (Array m sh a)
fromList sh xs =
   unsafeCreateWithSize sh $ \size ptr ->
      -- bound the write by the allocated element count
      pokeArray ptr (take size xs)

-- | Build a one-dimensional array from a list. Its shape is
-- @Shape.ZeroBased (length xs)@, so every element of @xs@ fits exactly.
--
-- The list is traversed twice: once for its length and once to store it.
vectorFromList ::
   (PrimMonad m, Storable a) => [a] -> m (Array m (Shape.ZeroBased Int) a)
vectorFromList xs =
   unsafeCreate vectorShape (\ptr -> pokeArray ptr xs)
   where
      vectorShape = Shape.ZeroBased (length xs)