module Data.Array.Comfort.Storable (
   Array,
   shape,
   reshape,
   mapShape,

   (!),
   Array.toList,
   Array.fromList,
   Array.vectorFromList,

   Array.map,
   Array.mapWithIndex,
   (//),
   accumulate,
   fromAssociations,
   ) where

import qualified Data.Array.Comfort.Storable.Mutable.Internal as MutArray
import qualified Data.Array.Comfort.Storable.Internal as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Internal (Array)

import Foreign.Storable (Storable)

import Control.Monad.ST (runST)

import Data.Foldable (forM_)

import Text.Printf (printf)

import Prelude hiding (map)


-- | Return the shape of an array (delegates to 'Array.shape').
shape :: Array sh a -> sh
shape arr = Array.shape arr

-- | Give the array the new shape @sh1@.
--
-- The old and new shapes must have the same 'Shape.size'.
-- Otherwise this calls 'error' with a message that reports both sizes.
reshape :: (Shape.C sh0, Shape.C sh1) => sh1 -> Array sh0 a -> Array sh1 a
reshape sh1 arr
   | oldSize == newSize = Array.reshape sh1 arr
   | otherwise =
        error $ printf
           ("Array.Comfort.Storable.reshape: " ++
            "different sizes of old (%d) and new (%d) shape")
           oldSize newSize
   where
      oldSize = Shape.size (shape arr)
      newSize = Shape.size sh1

-- | Replace the array's shape with @f@ applied to its current shape.
--
-- This goes through 'reshape', so it fails in the same way when @f@ yields
-- a shape of a different size.
mapShape ::
   (Shape.C sh0, Shape.C sh1) => (sh0 -> sh1) -> Array sh0 a -> Array sh1 a
mapShape f arr = reshape newShape arr
   where newShape = f (shape arr)


infixl 9 !

-- | Read the element stored at the given index.
--
-- The array is only read and never written, so viewing it through
-- 'MutArray.unsafeThaw' (presumably without copying) is harmless here.
-- NOTE(review): out-of-range indices behave however 'MutArray.read' does.
(!) :: (Shape.Indexed sh, Storable a) => Array sh a -> Shape.Index sh -> a
arr ! ix =
   runST (MutArray.unsafeThaw arr >>= \marr -> MutArray.read marr ix)


-- | Return a copy of the array with the listed @(index, value)@ pairs
-- written into it.
--
-- Writes happen in list order, so a later pair for the same index wins.
-- The input is thawed with 'MutArray.thaw' rather than 'MutArray.unsafeThaw'.
-- That is presumably a copying thaw, which would leave the argument untouched.
(//) ::
   (Shape.Indexed sh, Storable a) =>
   Array sh a -> [(Shape.Index sh, a)] -> Array sh a
arr // updates =
   runST (do
      marr <- MutArray.thaw arr
      mapM_ (\(ix, x) -> MutArray.write marr ix x) updates
      MutArray.unsafeFreeze marr)

-- | Fold each @(index, b)@ pair into the array.
--
-- The element at @index@ is replaced by @f old b@, and pairs are processed
-- in list order. As in '//', the input is thawed with 'MutArray.thaw'.
-- That is presumably a copying thaw, so the argument array is not modified.
accumulate ::
   (Shape.Indexed sh, Storable a) =>
   (a -> b -> a) -> Array sh a -> [(Shape.Index sh, b)] -> Array sh a
accumulate f arr xs =
   runST (do
      marr <- MutArray.thaw arr
      let step (ix, b) = MutArray.update marr ix (\old -> f old b)
      mapM_ step xs
      MutArray.unsafeFreeze marr)

-- | Build an array of shape @sh@ in which every element starts as the given
-- default value. The listed @(index, value)@ pairs are then written over it.
--
-- Writes happen in list order, so a later pair for the same index wins.
fromAssociations ::
   (Shape.Indexed sh, Storable a) =>
   sh -> a -> [(Shape.Index sh, a)] -> Array sh a
fromAssociations sh deflt assocs =
   runST (do
      marr <- MutArray.new sh deflt
      sequence_ [MutArray.write marr ix x | (ix, x) <- assocs]
      MutArray.unsafeFreeze marr)