{-# LANGUAGE MagicHash, CPP, UnboxedTuples, BangPatterns, FlexibleContexts #-}
{-# LANGUAGE TypeSynonymInstances #-}
module Internal.Vector(
I,Z,R,C,
fi,ti,
Vector, fromList, unsafeToForeignPtr, unsafeFromForeignPtr, unsafeWith,
createVector, avec, inlinePerformIO,
toList, dim, (@>), at', (|>),
vjoin, subVector, takesV, idxs,
buildVector,
asReal, asComplex,
toByteString,fromByteString,
zipVector, unzipVector, zipVectorWith, unzipVectorWith,
foldVector, foldVectorG, foldVectorWithIndex, foldLoop,
mapVector, mapVectorM, mapVectorM_,
mapVectorWithIndex, mapVectorWithIndexM, mapVectorWithIndexM_
) where
import Foreign.Marshal.Array
import Foreign.ForeignPtr
import Foreign.Ptr
import Foreign.Storable
import Foreign.C.Types(CInt)
import Data.Int(Int64)
import Data.Complex
import System.IO.Unsafe(unsafePerformIO)
import GHC.ForeignPtr(mallocPlainForeignPtrBytes)
import GHC.Base(realWorld#, IO(IO), when)
import qualified Data.Vector.Storable as Vector
import Data.Vector.Storable(Vector, fromList, unsafeToForeignPtr, unsafeFromForeignPtr, unsafeWith)
#ifdef BINARY
import Data.Binary
import Control.Monad(replicateM)
import qualified Data.ByteString.Internal as BS
import Data.Vector.Storable.Internal(updPtr)
#endif
-- | Short alias for 'CInt'.
type I = CInt
-- | Short alias for 'Int64'.
type Z = Int64
-- | Short alias for 'Double'.
type R = Double
-- | Short alias for @'Complex' 'Double'@.
type C = Complex Double
-- | Convert an 'Int' to a 'CInt' using 'fromIntegral'.
fi :: Int -> CInt
fi n = fromIntegral n
-- | Convert a 'CInt' to an 'Int' using 'fromIntegral'.
ti :: CInt -> Int
ti c = fromIntegral c
-- | Number of elements in a vector (delegates to 'Vector.length').
dim :: (Storable t) => Vector t -> Int
dim v = Vector.length v
{-# INLINE avec #-}
-- | Feed a vector to a continuation that expects its length (as a 'CInt')
-- followed by a pointer to its first element.
--
-- The pointer is obtained with 'unsafeWith' and the whole thing is run
-- through 'inlinePerformIO', so the result is returned as a pure value.
avec :: Storable a => (CInt -> Ptr a -> b) -> Vector a -> b
avec f v = inlinePerformIO $ unsafeWith v $ \ptr -> return (withLen ptr)
  where
    -- The continuation already applied to the vector length.
    withLen = f (fromIntegral (Vector.length v))
infixl 1 `avec`
-- | Allocate a fresh vector with room for @n@ elements.
--
-- The buffer comes from 'mallocPlainForeignPtrBytes' and is sized as
-- @n * sizeOf element@; nothing in this function writes to it, so the
-- initial contents are whatever the allocator hands back.
--
-- Calls 'error' when @n@ is negative.
createVector :: Storable a => Int -> IO (Vector a)
createVector n =
    when (n < 0) (error ("trying to createVector of negative dim: " ++ show n))
        >> fmap wrap (allocFor undefined)
  where
    -- Wrap the raw buffer as a vector with offset 0 and length n.
    wrap buffer = unsafeFromForeignPtr buffer 0 n

    -- The argument is never evaluated; it only selects the element type
    -- whose 'sizeOf' determines the allocation size.
    allocFor :: Storable b => b -> IO (ForeignPtr b)
    allocFor proxy = mallocPlainForeignPtrBytes (n * sizeOf proxy)
-- | Run an 'IO' action against the vector's element pointer (obtained
-- with 'unsafeWith') and return its result as a pure value via
-- 'inlinePerformIO'.
safeRead :: Storable a => Vector a -> (Ptr a -> IO b) -> b
safeRead v = \action -> inlinePerformIO (unsafeWith v action)
{-# INLINE safeRead #-}
-- | Run an 'IO' action as a pure computation by unwrapping the 'IO'
-- constructor, applying the underlying state function to 'realWorld#'
-- and discarding the resulting state token.
--
-- NOTE(review): this bypasses whatever safeguards 'unsafePerformIO'
-- provides; callers are presumably expected to use it only for cheap,
-- side-effect-free reads -- confirm before adding new uses.
inlinePerformIO :: IO a -> a
inlinePerformIO (IO m) = case m realWorld# of (# _, r #) -> r
{-# INLINE inlinePerformIO #-}
-- | Copy all elements of a vector into a list, in index order, by
-- reading @dim v@ elements from the vector's buffer with 'peekArray'.
toList :: Storable a => Vector a -> [a]
toList v = safeRead v (\ptr -> peekArray len ptr)
  where
    len = dim v
-- | Build a vector from the first @n@ elements of a list.
--
-- Calls 'error' when the list has fewer than @n@ elements (which also
-- covers a negative @n@, since the taken prefix then has length 0).
(|>) :: (Storable a) => Int -> [a] -> Vector a
infixl 9 |>
n |> l =
    if length prefix == n
        then fromList prefix
        else error "list too short for |>"
  where
    -- Only the first n elements are ever inspected.
    prefix = take n l
-- | Build a vector of 'I' indices from a list of 'Int's, converting each
-- element with 'fromIntegral'.
idxs :: [Int] -> Vector I
idxs js = fromList [fromIntegral j | j <- js]
-- | Take a slice of a vector; a direct synonym for 'Vector.slice'.
subVector :: Storable t
          => Int       -- ^ first argument passed to 'Vector.slice' (start index)
          -> Int       -- ^ second argument passed to 'Vector.slice' (length)
          -> Vector t  -- ^ source vector
          -> Vector t
subVector start len v = Vector.slice start len v
-- | Bounds-checked element access.
--
-- Calls 'error' when the index is negative or not less than @dim v@;
-- otherwise reads the element with 'at''.
(@>) :: Storable t => Vector t -> Int -> t
infixl 9 @>
v @> n
    | n < 0 || n >= dim v = error "vector index out of range"
    | otherwise           = at' v n
{-# INLINE (@>) #-}
-- | Element access without any bounds check: reads element @n@ straight
-- from the vector's buffer with 'peekElemOff'.
at' :: Storable a => Vector a -> Int -> a
at' v n = safeRead v (\ptr -> peekElemOff ptr n)
{-# INLINE at' #-}
-- | Concatenate a list of vectors into one.
--
-- An empty list yields an empty vector and a singleton list returns its
-- only element unchanged. Otherwise a buffer of the total size is
-- allocated and every input is copied into it in order with 'copyArray'.
vjoin :: Storable t => [Vector t] -> Vector t
vjoin vs = case vs of
    []  -> fromList []
    [v] -> v
    _   -> unsafePerformIO $ do
        -- The destination is freshly allocated here and only returned
        -- once every piece has been copied in.
        out <- createVector (sum (map dim vs))
        unsafeWith out $ \dst -> copyAll dst vs
        return out
  where
    -- Copy each vector to the destination pointer, advancing it by the
    -- number of elements just written.
    copyAll _ [] = return ()
    copyAll dst (x:rest) = do
        let len = dim x
        unsafeWith x $ \src -> copyArray dst src len
        copyAll (advancePtr dst len) rest
-- | Split a vector into consecutive pieces of the given sizes.
--
-- Calls 'error' when the sizes add up to more than @dim w@. Elements
-- left over after the last requested piece are dropped.
takesV :: Storable t => [Int] -> Vector t -> [Vector t]
takesV ms w =
    if sum ms > dim w
        then error $ "takesV " ++ show ms ++ " on dim = " ++ (show $ dim w)
        else split ms w
  where
    -- Peel the next piece off the front and continue with the remainder.
    split sizes v = case sizes of
        []       -> []
        (n:rest) ->
            let piece     = subVector 0 n v
                remainder = subVector n (dim v - n) v
            in piece : split rest remainder
-- | Reinterpret a complex vector as a vector of its scalar components,
-- sharing the same buffer: the pointer is cast and both the offset and
-- the length are doubled.
asReal :: (RealFloat a, Storable a) => Vector (Complex a) -> Vector a
asReal v =
    let (fp, off, len) = unsafeToForeignPtr v
    in unsafeFromForeignPtr (castForeignPtr fp) (2 * off) (2 * len)
-- | Reinterpret a scalar vector as a vector of complex numbers, sharing
-- the same buffer: the pointer is cast and both the offset and the
-- length are halved with integer division, so an odd offset or length
-- is rounded down.
asComplex :: (RealFloat a, Storable a) => Vector a -> Vector (Complex a)
asComplex v =
    let (fp, off, len) = unsafeToForeignPtr v
    in unsafeFromForeignPtr (castForeignPtr fp) (off `div` 2) (len `div` 2)
-- | Apply a function to every element, producing a new vector of the
-- same length. Elements are processed from the last index down to 0.
mapVector :: (Storable a, Storable b) => (a-> b) -> Vector a -> Vector b
mapVector f v = unsafePerformIO $ do
    -- The output buffer is allocated here and returned only after the
    -- loop has written every slot.
    out <- createVector (dim v)
    unsafeWith v $ \src ->
        unsafeWith out $ \dst -> do
            let loop !i
                    | i < 0 = return ()
                    | otherwise = do
                        x <- peekElemOff src i
                        pokeElemOff dst i (f x)
                        loop (i - 1)
            loop (dim v - 1)
    return out
{-# INLINE mapVector #-}
-- | Combine two vectors element-wise with a binary function. The result
-- has the length of the shorter input. Elements are processed from the
-- last index down to 0.
zipVectorWith :: (Storable a, Storable b, Storable c) => (a-> b -> c) -> Vector a -> Vector b -> Vector c
zipVectorWith f u v = unsafePerformIO $ do
    let len = min (dim u) (dim v)
    -- Fresh output buffer, returned only after the loop has filled it.
    out <- createVector len
    unsafeWith u $ \pa ->
        unsafeWith v $ \pb ->
            unsafeWith out $ \pc -> do
                let loop !i
                        | i < 0 = return ()
                        | otherwise = do
                            a <- peekElemOff pa i
                            b <- peekElemOff pb i
                            pokeElemOff pc i (f a b)
                            loop (i - 1)
                loop (len - 1)
    return out
{-# INLINE zipVectorWith #-}
-- | Apply a pair-producing function to every element of a vector of
-- pairs and collect the two components into separate vectors, each of
-- the same length as the input. Elements are processed from the last
-- index down to 0.
unzipVectorWith :: (Storable (a,b), Storable c, Storable d)
                => ((a,b) -> (c,d)) -> Vector (a,b) -> (Vector c,Vector d)
unzipVectorWith f u = unsafePerformIO $ do
    let len = dim u
    -- Two fresh output buffers, returned only after the loop has run.
    firsts  <- createVector len
    seconds <- createVector len
    unsafeWith u $ \src ->
        unsafeWith firsts $ \dst1 ->
            unsafeWith seconds $ \dst2 -> do
                let loop !i
                        | i < 0 = return ()
                        | otherwise = do
                            pair <- peekElemOff src i
                            -- Lazy pattern: the pair is only forced when a
                            -- component is written.
                            let (c, d) = f pair
                            pokeElemOff dst1 i c
                            pokeElemOff dst2 i d
                            loop (i - 1)
                loop (len - 1)
    return (firsts, seconds)
{-# INLINE unzipVectorWith #-}
-- | Fold over the elements of a vector, visiting them from the last
-- index down to 0. The accumulator is forced before each step.
-- An empty vector returns the seed unchanged.
foldVector :: Storable a => (a -> b -> b) -> b -> Vector a -> b
foldVector f x v = unsafePerformIO $
    unsafeWith v $ \ptr -> do
        let loop i acc
                | i == (-1 :: Int) = return acc
                | otherwise = acc `seq` do
                    y <- peekElemOff ptr i
                    loop (i - 1) (f y acc)
        loop (dim v - 1) x
{-# INLINE foldVector #-}
-- | Like a right-to-left fold over the vector, but the step function
-- also receives the index of the element being visited. Indices run
-- from @dim v - 1@ down to 0 and the accumulator is forced before each
-- step. An empty vector returns the seed unchanged.
foldVectorWithIndex :: Storable a => (Int -> a -> b -> b) -> b -> Vector a -> b
foldVectorWithIndex f x v = unsafePerformIO $
    unsafeWith v $ \ptr -> do
        let loop i acc
                | i == (-1 :: Int) = return acc
                | otherwise = acc `seq` do
                    y <- peekElemOff ptr i
                    loop (i - 1) (f i y acc)
        loop (dim v - 1) x
{-# INLINE foldVectorWithIndex #-}
-- | Fold a step function over the indices @d - 1, d - 2 .. 0@, threading
-- an accumulator that starts at @s0@:
--
-- > foldLoop f s0 d == f 0 (f 1 (... (f (d - 1) s0)))
--
-- When @d <= 0@ there are no indices to visit and the seed is returned
-- unchanged. (Previously such a call started the loop below index 0,
-- never matched the terminating @0@ case and recursed forever.)
foldLoop :: (Int -> t -> t) -> t -> Int -> t
foldLoop f s0 d
    | d <= 0    = s0
    | otherwise = go (d - 1) s0
  where
    -- Index 0 is the last step; its result is the final value.
    go 0 s = f (0::Int) s
    -- Both the index and the accumulator are forced at every step.
    go !j !s = go (j - 1) (f j s)
-- | Generic fold driven by 'foldLoop' over the indices of a vector.
--
-- At each index @k@ the step function receives @k@, an unchecked
-- accessor (@Int -> element@, reading via 'peekElemOff') for the whole
-- vector, and the current accumulator.
--
-- An empty vector returns the seed unchanged, without calling the step
-- function: there is no valid index to hand to it, and passing a
-- dimension of 0 on to 'foldLoop' is avoided.
foldVectorG :: Storable t1 => (Int -> (Int -> t1) -> t -> t) -> t -> Vector t1 -> t
foldVectorG f s0 v
    | dim v <= 0 = s0
    | otherwise  = foldLoop g s0 (dim v)
  where
    g !k !s = f k (safeRead v . flip peekElemOff) s
    {-# INLINE g #-}
{-# INLINE foldVectorG #-}
-- | Monadic map: run @f@ on each element in index order (0 upwards) and
-- collect the results in a new vector of the same length.
--
-- An empty input yields an empty output without running @f@.
-- (Previously the loop compared the index only for equality with
-- @dim v - 1@, which is @-1@ for an empty vector, so it read past the
-- buffer and never terminated.)
--
-- NOTE(review): element reads and writes go through 'inlinePerformIO'
-- wrapped in @return $!@, so they happen only when the monad's bind
-- forces those values -- confirm this holds for lazy monads.
mapVectorM :: (Storable a, Storable b, Monad m) => (a -> m b) -> Vector a -> m (Vector b)
mapVectorM f v = do
    w <- return $! unsafePerformIO $! createVector (dim v)
    mapVectorM' w 0 (dim v - 1)
    return w
  where
    -- Unchecked read of element k of the input.
    fetch k = inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
    -- Unchecked write of y into slot k of the output.
    store w' k y = inlinePerformIO $! unsafeWith w' $! \q -> pokeElemOff q k y

    -- Process indices k..t inclusive; nothing to do once k has passed t.
    mapVectorM' w' !k !t
        | k > t = return ()
        | k == t = do
            x <- return $! fetch k
            y <- f x
            return $! store w' k y
        | otherwise = do
            x <- return $! fetch k
            y <- f x
            _ <- return $! store w' k y
            mapVectorM' w' (k+1) t
{-# INLINE mapVectorM #-}
-- | Run a monadic action on each element in index order (0 upwards),
-- discarding the results.
--
-- An empty vector runs no actions. (Previously the loop compared the
-- index only for equality with @dim v - 1@, which is @-1@ for an empty
-- vector, so it read past the buffer and never terminated.)
mapVectorM_ :: (Storable a, Monad m) => (a -> m ()) -> Vector a -> m ()
mapVectorM_ f v = mapVectorM' 0 (dim v - 1)
  where
    -- Unchecked read of element k.
    fetch k = inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k

    -- Process indices k..t inclusive; nothing to do once k has passed t.
    mapVectorM' !k !t
        | k > t = return ()
        | k == t = do
            x <- return $! fetch k
            f x
        | otherwise = do
            x <- return $! fetch k
            _ <- f x
            mapVectorM' (k+1) t
{-# INLINE mapVectorM_ #-}
-- | Monadic map with index: run @f k x@ for each element @x@ at index
-- @k@, in index order (0 upwards), and collect the results in a new
-- vector of the same length.
--
-- An empty input yields an empty output without running @f@.
-- (Previously the loop compared the index only for equality with
-- @dim v - 1@, which is @-1@ for an empty vector, so it read past the
-- buffer and never terminated.)
--
-- NOTE(review): element reads and writes go through 'inlinePerformIO'
-- wrapped in @return $!@, so they happen only when the monad's bind
-- forces those values -- confirm this holds for lazy monads.
mapVectorWithIndexM :: (Storable a, Storable b, Monad m) => (Int -> a -> m b) -> Vector a -> m (Vector b)
mapVectorWithIndexM f v = do
    w <- return $! unsafePerformIO $! createVector (dim v)
    mapVectorM' w 0 (dim v - 1)
    return w
  where
    -- Unchecked read of element k of the input.
    fetch k = inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
    -- Unchecked write of y into slot k of the output.
    store w' k y = inlinePerformIO $! unsafeWith w' $! \q -> pokeElemOff q k y

    -- Process indices k..t inclusive; nothing to do once k has passed t.
    mapVectorM' w' !k !t
        | k > t = return ()
        | k == t = do
            x <- return $! fetch k
            y <- f k x
            return $! store w' k y
        | otherwise = do
            x <- return $! fetch k
            y <- f k x
            _ <- return $! store w' k y
            mapVectorM' w' (k+1) t
{-# INLINE mapVectorWithIndexM #-}
-- | Run a monadic action @f k x@ on each element @x@ at index @k@, in
-- index order (0 upwards), discarding the results.
--
-- An empty vector runs no actions. (Previously the loop compared the
-- index only for equality with @dim v - 1@, which is @-1@ for an empty
-- vector, so it read past the buffer and never terminated.)
mapVectorWithIndexM_ :: (Storable a, Monad m) => (Int -> a -> m ()) -> Vector a -> m ()
mapVectorWithIndexM_ f v = mapVectorM' 0 (dim v - 1)
  where
    -- Unchecked read of element k.
    fetch k = inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k

    -- Process indices k..t inclusive; nothing to do once k has passed t.
    mapVectorM' !k !t
        | k > t = return ()
        | k == t = do
            x <- return $! fetch k
            f k x
        | otherwise = do
            x <- return $! fetch k
            _ <- f k x
            mapVectorM' (k+1) t
{-# INLINE mapVectorWithIndexM_ #-}
-- | Apply a function to every element together with its index,
-- producing a new vector of the same length. Elements are processed
-- from the last index down to 0.
mapVectorWithIndex :: (Storable a, Storable b) => (Int -> a -> b) -> Vector a -> Vector b
mapVectorWithIndex f v = unsafePerformIO $ do
    -- Fresh output buffer, returned only after the loop has filled it.
    out <- createVector (dim v)
    unsafeWith v $ \src ->
        unsafeWith out $ \dst -> do
            let loop !i
                    | i < 0 = return ()
                    | otherwise = do
                        x <- peekElemOff src i
                        pokeElemOff dst i (f i x)
                        loop (i - 1)
            loop (dim v - 1)
    return out
{-# INLINE mapVectorWithIndex #-}
#ifdef BINARY
-- | Number of elements per piece used by 'chunks'.
chunk :: Int
chunk = 5000

-- | Split a dimension into piece sizes: as many full pieces of 'chunk'
-- elements as fit, followed by one final piece holding the remainder
-- when it is non-zero.
chunks :: Int -> [Int]
chunks d
    | leftover /= 0 = reverse (leftover : fullPieces)
    | otherwise     = fullPieces
  where
    (count, leftover) = d `divMod` chunk
    fullPieces = replicate count chunk
-- | Serialise every element of a vector, in order, with 'put'. The
-- element list is forced to weak head normal form before any element
-- is written.
putVector :: (Binary a, Storable a) => Vector a -> Put
putVector v =
    let elems = toList v
    in elems `seq` mapM_ put elems
-- | Deserialise @d@ elements with 'get' and pack them into a vector,
-- which is forced before being returned.
getVector :: (Binary a, Storable a) => Int -> Get (Vector a)
getVector d =
    replicateM d get >>= \elems ->
        return $! fromList elems
-- | View a vector's buffer as a strict 'BS.ByteString' without copying:
-- the pointer is cast, and the offset and length are scaled from
-- elements to bytes.
toByteString :: Storable t => Vector t -> BS.ByteString
toByteString v = BS.PS bytes (elemSize * off) (elemSize * dim v)
  where
    (fp, off, _len) = unsafeToForeignPtr v
    bytes = castForeignPtr fp
    -- The element expression only selects the type handed to 'sizeOf';
    -- presumably it is never evaluated, so an empty vector is fine --
    -- TODO confirm for non-standard Storable instances.
    elemSize = sizeOf (v @> 0)
-- | View a strict 'BS.ByteString' as a vector without copying. The
-- byte offset is folded into the pointer and the element count is the
-- byte length divided (rounding down) by the element size.
fromByteString :: Storable t => BS.ByteString -> Vector t
fromByteString (BS.PS fp o n) = result
  where
    -- Pointer advanced by the ByteString's byte offset.
    shifted = updPtr (\p -> p `plusPtr` o) fp
    -- 'result' is referenced here only to select the element type
    -- handed to 'sizeOf'.
    elemSize = sizeOf (result @> 0)
    count = n `div` elemSize
    result = unsafeFromForeignPtr (castForeignPtr shifted) 0 count
-- | Binary serialisation of vectors: the dimension is written first,
-- then the elements, piece by piece, with piece sizes given by
-- 'chunks'. Reading mirrors this and joins the pieces with 'vjoin'.
instance (Binary a, Storable a) => Binary (Vector a) where
    put v = put len >> (mapM_ putVector $! takesV (chunks len) v)
      where
        len = dim v

    get =
        get >>= \len ->
        mapM getVector (chunks len) >>= \pieces ->
        return $! vjoin pieces
#endif
-- | Build a vector of the given length whose element at index @i@ is
-- @f i@, for @i@ from 0 to @len - 1@. A non-positive length gives an
-- empty vector, since the index range is then empty.
buildVector :: Storable a => Int -> (Int -> a) -> Vector a
buildVector len f = fromList [f i | i <- [0 .. (len - 1)]]
-- | Pair up the elements of two vectors; 'zipVectorWith' specialised to
-- the tuple constructor.
zipVector :: (Storable a, Storable b, Storable (a,b)) => Vector a -> Vector b -> Vector (a,b)
zipVector u v = zipVectorWith (,) u v
-- | Split a vector of pairs into a pair of vectors; 'unzipVectorWith'
-- specialised to 'id'.
unzipVector :: (Storable a, Storable b, Storable (a,b)) => Vector (a,b) -> (Vector a,Vector b)
unzipVector v = unzipVectorWith id v