-- | Mutable BLAS vectors, views into them, and in-place vector updates.
module Numeric.BLAS.Vector.Mutable (
   T,
   Sourced,
   C,
   shape,
   fromVector,
   sourcedFromVector,
   slice,
   sliceVector,
   slices,
   slicesVector,
   new,
   thawSlice,
   fromChunk,
   add,
   sub,
   mac,
   ) where

import Control.Applicative (liftA2, (<$>))
import Control.Monad.IO.Class (liftIO)
import Control.Monad.Primitive (PrimMonad, unsafeIOToPrim)
import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Foreign.C.Types (CInt)
import Foreign.Marshal.Array (advancePtr)
-- import Foreign.ForeignPtr (withForeignPtr, castForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable)

import qualified Data.Array.Comfort.Shape as Shape
import qualified Data.Array.Comfort.Storable.Mutable.Private as Array
import Data.Array.Comfort.Storable.Mutable.Private (Array(Array))
import qualified Foreign.Marshal.Array.Guarded as ForeignArray
import qualified Numeric.Netlib.Class as Class
import qualified Numeric.Netlib.Utility as Call

import qualified Numeric.BLAS.FFI.Generic as Blas
-- import qualified Numeric.BLAS.FFI.Complex as BlasComplex
-- import qualified Numeric.BLAS.FFI.Real as BlasReal
import Numeric.BLAS.Private (fill)
import qualified Numeric.BLAS.Scalar as Scalar
import qualified Numeric.BLAS.Slice as Slice
import qualified Numeric.BLAS.Subobject.Shape as Subshape
import qualified Numeric.BLAS.Subobject.View as View
import qualified Numeric.BLAS.Vector.Chunk as Chunk
import qualified Numeric.BLAS.Vector.SlicePrivate as VectorSlice
import Numeric.BLAS.Vector.SlicePrivate ((<*|>))


-- | A plain mutable storable array, used directly as a vector.
type Vector = Array

-- | Zero-based 'Int' shape.
type ShapeInt = Shape.ZeroBased Int

-- ToDo: generalize 'lay' type parameter?
-- | A view into mutable storage whose shape is a 'Subshape.Slice'.
-- The start offset and the element increment handed to BLAS are read
-- from that shape (see the 'C' instance for 'T').
newtype T m slice a = Cons (Array m (Subshape.Slice slice) a)

-- | A whole mutable vector paired with a 'Slice.T' that selects the
-- part of it handed to BLAS.
data Sourced sh m slice a = Sourced (Slice.T slice) (Vector m sh a)
   -- deriving (Show)

-- | View an entire mutable vector as a 'T'.
fromVector :: (Shape.C sh) => Vector m sh a -> T m sh a
fromVector = Cons . Array.mapShape Subshape.fromVector

-- | Pair a mutable vector with the slice built from its whole shape.
sourcedFromVector :: (Shape.C sh) => Vector m sh a -> Sourced sh m sh a
sourcedFromVector a = Sourced (Slice.fromShape $ Array.shape a) a

-- | Narrow a view by transforming its slice.
-- Only the shape is changed (via 'Array.mapShape').
slice :: (Slice.T shA -> Slice.T shB) -> T m shA a -> T m shB a
slice f (Cons xs) = Cons $ Array.mapShape (Subshape.focus f) xs

-- | 'slice' applied to a whole mutable vector.
sliceVector ::
   (Shape.C shA) =>
   (Slice.T shA -> Slice.T shB) -> Vector m shA a -> T m shB a
sliceVector f = slice f . fromVector

-- | Derive several views at once.
-- Every resulting view reuses the storage of the input view.
slices ::
   (Functor f) =>
   (Slice.T shA -> f (Slice.T shB)) -> T m shA a -> f (T m shB a)
slices f (Cons (Array sh x)) = Cons . flip Array x <$> Subshape.focusMany f sh

-- | 'slices' applied to a whole mutable vector.
slicesVector ::
   (Functor f, Shape.C shA) =>
   (Slice.T shA -> f (Slice.T shB)) -> Vector m shA a -> f (T m shB a)
slicesVector f = slices f . fromVector


-- ToDo: increment must not be zero (maybe even positive)
-- | Mutable vector-like objects that can be passed to BLAS routines
-- as a start pointer plus an element increment.
class C v where
   -- | Logical shape of the vector.
   shape :: (PrimMonad m) => v m slice a -> slice
   -- | Distance, in elements, between consecutive vector entries.
   increment :: (PrimMonad m) => v m slice a -> Int
   -- | Pointer to the first vector entry.  The pointer is obtained inside
   -- 'ContT', so it is only meant to be used within that continuation.
   startArg ::
      (PrimMonad m, Storable a) => v m slice a -> Call.FortranIO r (Ptr a)

-- | A plain array is contiguous: it starts at its first element and
-- has increment 1.
instance C Array where
   shape = Array.shape
   increment _arr = 1
   startArg (Array _sh x) = ContT $ ForeignArray.withMutablePtr x

-- | Offset and increment come from the 'Subshape.Slice' shape.
instance C T where
   shape (Cons arr) = Subshape.shape $ Array.shape arr
   increment (Cons arr) = View.elemInc $ Subshape.layout $ Array.shape arr
   startArg (Cons (Array sh x)) =
      flip advancePtr (Subshape.start sh)
         <$> ContT (ForeignArray.withMutablePtr x)

-- | Offset and increment come from the attached 'Slice.T'.
instance C (Sourced sh) where
   shape (Sourced (Slice.Cons _s _k slc) _arr) = slc
   increment (Sourced (Slice.Cons _s k _slc) _arr) = View.elemInc k
   startArg (Sourced (Slice.Cons s _k _slc) (Array _sh x)) =
      flip advancePtr s <$> ContT (ForeignArray.withMutablePtr x)

-- | Start pointer and increment (marshalled as a 'CInt' pointer),
-- in the form BLAS vector arguments expect.
sliceArg ::
   (PrimMonad m, C v, Storable a) =>
   v m slice a -> Call.FortranIO r (Ptr a, Ptr CInt)
sliceArg x = liftA2 (,) (startArg x) (Call.cint $ increment x)


-- | Allocate a vector of shape @sh@ and initialise its storage with
-- 'fill' using the value @x@.
new ::
   (PrimMonad m, Shape.C sh, Class.Floating a) =>
   sh -> a -> m (Vector m sh a)
new sh x = Array.unsafeCreateWithSize sh $ fill x

-- | Convert a 'Chunk.T' into a mutable vector with zero-based shape.
fromChunk :: (PrimMonad m, Storable a) => Chunk.T a -> m (Vector m ShapeInt a)
fromChunk = Chunk.toMutableVector

-- cf. Vector.Slice.toVector
-- | Copy an immutable vector slice into freshly allocated mutable
-- storage using BLAS @copy@ with destination increment 1.
thawSlice ::
   (PrimMonad m, VectorSlice.C v, Shape.C sh, Class.Floating a) =>
   v sh a -> m (Vector m sh a)
thawSlice x =
   Array.unsafeCreateWithSize (VectorSlice.shape x) $ \n yPtr ->
      evalContT $ Call.run $
         pure Blas.copy <*> Call.cint n <*|> x <*> pure yPtr <*> Call.cint 1


-- | In-place update: @add y x@ performs @y := y + x@ and
-- @sub y x@ performs @y := y - x@ (both via 'mac').
add, sub ::
   (PrimMonad m, C v, VectorSlice.C w, Shape.C sh, Eq sh, Class.Floating a) =>
   v m sh a -> w sh a -> m ()
add = flip mac Scalar.one
sub = flip mac Scalar.minusOne

-- | @mac y alpha x@ performs @y := alpha*x + y@ in place using BLAS
-- @axpy@.  The shapes of @x@ and @y@ are checked for equality with
-- 'Call.assert' before the call.
mac ::
   (PrimMonad m, C v, VectorSlice.C w, Shape.C sh, Eq sh, Class.Floating a) =>
   v m sh a -> a -> w sh a -> m ()
mac y alpha x = unsafeIOToPrim $ do
   let sh = VectorSlice.shape x
   Call.assert "mac: shapes mismatch" (sh == shape y)
   evalContT $ do
      let n = Shape.size sh
      nPtr <- Call.cint n
      alphaPtr <- Call.number alpha
      (xPtr, incxPtr) <- VectorSlice.sliceArg x
      (yPtr, incyPtr) <- sliceArg y
      liftIO $ Blas.axpy nPtr alphaPtr xPtr incxPtr yPtr incyPtr