{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ViewPatterns #-}
module Internal.ST (
ST, runST,
STVector, newVector, thawVector, freezeVector, runSTVector,
readVector, writeVector, modifyVector, liftSTVector,
STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix,
readMatrix, writeMatrix, modifyMatrix, liftSTMatrix,
mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..), gemmm, Slice(..),
newUndefinedVector,
unsafeReadVector, unsafeWriteVector,
unsafeThawVector, unsafeFreezeVector,
newUndefinedMatrix,
unsafeReadMatrix, unsafeWriteMatrix,
unsafeThawMatrix, unsafeFreezeMatrix
) where
import Internal.Vector
import Internal.Matrix
import Internal.Vectorized
import Control.Monad.ST(ST, runST)
import Foreign.Storable(Storable, peekElemOff, pokeElemOff)
import Control.Monad.ST.Unsafe(unsafeIOToST)
{-# INLINE ioReadV #-}
-- | Read element @k@ of a vector's storage buffer in 'IO'.
-- No bounds check is performed; callers must guarantee @0 <= k < dim v@.
ioReadV :: Storable t => Vector t -> Int -> IO t
ioReadV v k = unsafeWith v (`peekElemOff` k)

{-# INLINE ioWriteV #-}
-- | Overwrite element @k@ of a vector's storage buffer in 'IO'.
-- No bounds check is performed. The write goes straight into the buffer of
-- @v@, so every value sharing that buffer observes it.
ioWriteV :: Storable t => Vector t -> Int -> t -> IO ()
ioWriteV v k x = unsafeWith v (\ptr -> pokeElemOff ptr k x)
-- | A mutable vector tied to the state thread @s@.
--
-- It wraps an ordinary 'Vector' whose storage buffer is written in place by
-- the write operations of this module (which poke through 'unsafeWith').
-- Any immutable 'Vector' sharing that buffer (see 'unsafeThawVector' and
-- 'unsafeFreezeVector') therefore observes those writes.
newtype STVector s t = STVector (Vector t)
-- | Make a mutable copy of an immutable vector (via 'cloneVector').
-- Later writes to the result do not affect the input.
thawVector :: Storable t => Vector t -> ST s (STVector s t)
thawVector v = unsafeIOToST (STVector <$> cloneVector v)

-- | Wrap an immutable vector as mutable /without copying/.
-- Writes through the result are visible in the original vector, so the
-- input must not be used afterwards.
unsafeThawVector :: Storable t => Vector t -> ST s (STVector s t)
unsafeThawVector v = unsafeIOToST (pure (STVector v))
-- | Run a state-thread computation that builds a mutable vector and return
-- the result as an immutable 'Vector'.
--
-- The final freeze does not copy. This is safe because the rank-2 type
-- keeps the mutable handle from escaping 'runST'.
runSTVector :: Storable t => (forall s . ST s (STVector s t)) -> Vector t
runSTVector st = runST (unsafeFreezeVector =<< st)
{-# INLINE unsafeReadVector #-}
-- | Read element @k@ of a mutable vector with no bounds check.
unsafeReadVector :: Storable t => STVector s t -> Int -> ST s t
unsafeReadVector (STVector x) k = unsafeIOToST (ioReadV x k)

{-# INLINE unsafeWriteVector #-}
-- | Write element @k@ of a mutable vector with no bounds check.
unsafeWriteVector :: Storable t => STVector s t -> Int -> t -> ST s ()
unsafeWriteVector (STVector x) k val = unsafeIOToST (ioWriteV x k val)
{-# INLINE modifyVector #-}
-- | Replace element @k@ with @f@ applied to its current value.
--
-- The read goes through the bounds-checked 'readVector'. The write-back can
-- therefore use the unchecked 'unsafeWriteVector', because the index has
-- already been validated.
modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s ()
modifyVector x k f = do
    old <- readVector x k
    unsafeWriteVector x k (f old)
-- | Apply a pure function to a snapshot of the vector's current contents.
-- The function receives a copy made by 'cloneVector', so later writes to
-- the mutable vector do not change what it sees.
liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s t -> ST s a
liftSTVector f (STVector x) = unsafeIOToST (f <$> cloneVector x)

-- | Safe freeze: return an immutable copy of the current contents.
freezeVector :: (Storable t) => STVector s t -> ST s (Vector t)
freezeVector = liftSTVector id

-- | Freeze without copying. The returned 'Vector' shares the mutable
-- buffer, so the 'STVector' must not be written afterwards.
unsafeFreezeVector :: (Storable t) => STVector s t -> ST s (Vector t)
unsafeFreezeVector (STVector x) = unsafeIOToST (pure x)
{-# INLINE safeIndexV #-}
-- | Bounds-checking wrapper for vector accessors.
--
-- Calls @f v k@ only when @0 <= k < dim v@. Otherwise it raises an 'error'
-- reporting the vector length and the offending index. The result type @b@
-- is left abstract so that both readers (@ST s t@) and writers
-- (@t -> ST s ()@) can be wrapped.
safeIndexV :: Storable t => (STVector s t -> Int -> b) -> STVector s t -> Int -> b
safeIndexV f (STVector v) k
    | k < 0 || k >= dim v = error $ "out of range error in vector (dim="
                                ++ show (dim v) ++ ", pos=" ++ show k ++ ")"
    | otherwise = f (STVector v) k

{-# INLINE readVector #-}
-- | Bounds-checked read of element @k@; raises an 'error' when out of range.
readVector :: Storable t => STVector s t -> Int -> ST s t
readVector = safeIndexV unsafeReadVector

{-# INLINE writeVector #-}
-- | Bounds-checked write of element @k@; raises an 'error' when out of range.
writeVector :: Storable t => STVector s t -> Int -> t -> ST s ()
writeVector = safeIndexV unsafeWriteVector
-- | Allocate a mutable vector of length @n@ via 'createVector'.
-- This function does not initialise the elements.
newUndefinedVector :: Storable t => Int -> ST s (STVector s t)
newUndefinedVector n = unsafeIOToST (STVector <$> createVector n)
{-# INLINE newVector #-}
-- | Allocate a mutable vector of length @n@ with every element set to @x@.
--
-- The buffer comes from 'newUndefinedVector' and is filled from index
-- @n - 1@ down to 0. The loop stops on any negative index, not only on
-- exactly @-1@. It therefore cannot run away (writing out of bounds) if a
-- negative length ever gets past the allocator.
newVector :: Storable t => t -> Int -> ST s (STVector s t)
newVector x n = do
    v <- newUndefinedVector n
    let go !k
            | k < 0     = return v
            | otherwise = unsafeWriteVector v k x >> go (k - 1 :: Int)
    go (n - 1)
{-# INLINE ioReadM #-}
-- | Read element (@r@, @c@) of a matrix in 'IO' without bounds checking.
-- The position in the flat buffer 'xdat' is @r * xRow m + c * xCol m@,
-- i.e. 'xRow' and 'xCol' act as the row and column strides.
ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t
ioReadM m r c = ioReadV (xdat m) pos
  where
    pos = r * xRow m + c * xCol m

{-# INLINE ioWriteM #-}
-- | Write element (@r@, @c@) of a matrix in 'IO' without bounds checking.
-- Uses the same offset formula as 'ioReadM'.
ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO ()
ioWriteM m r c val = ioWriteV (xdat m) pos val
  where
    pos = r * xRow m + c * xCol m
-- | A mutable matrix tied to the state thread @s@.
--
-- It wraps an ordinary 'Matrix' whose element buffer ('xdat') is written in
-- place by the matrix write operations of this module. Any immutable
-- 'Matrix' sharing that buffer (see 'unsafeThawMatrix' and
-- 'unsafeFreezeMatrix') observes those writes.
newtype STMatrix s t = STMatrix (Matrix t)
-- | Make a mutable matrix from a copy of the input produced by 'cloneMatrix'.
-- Writes to the result are not meant to affect the input.
thawMatrix :: Element t => Matrix t -> ST s (STMatrix s t)
thawMatrix m = unsafeIOToST (STMatrix <$> cloneMatrix m)

-- | Wrap an immutable matrix as mutable /without copying/.
-- Writes through the result are visible in the original matrix, so the
-- input must not be used afterwards.
unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t)
unsafeThawMatrix m = unsafeIOToST (pure (STMatrix m))
-- | Run a state-thread computation that builds a mutable matrix and return
-- the result as an immutable 'Matrix'.
--
-- The final freeze does not copy. This is safe because the rank-2 type
-- keeps the mutable handle from escaping 'runST'.
runSTMatrix :: Storable t => (forall s . ST s (STMatrix s t)) -> Matrix t
runSTMatrix st = runST (unsafeFreezeMatrix =<< st)
{-# INLINE unsafeReadMatrix #-}
-- | Read element (@r@, @c@) of a mutable matrix with no bounds check.
unsafeReadMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t
unsafeReadMatrix (STMatrix x) r c = unsafeIOToST (ioReadM x r c)

{-# INLINE unsafeWriteMatrix #-}
-- | Write element (@r@, @c@) of a mutable matrix with no bounds check.
unsafeWriteMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s ()
unsafeWriteMatrix (STMatrix x) r c val = unsafeIOToST (ioWriteM x r c val)
{-# INLINE modifyMatrix #-}
-- | Replace element (@r@, @c@) with @f@ applied to its current value.
--
-- The read goes through the bounds-checked 'readMatrix'. The unchecked
-- write-back therefore targets an index that has already been validated.
modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s ()
modifyMatrix x r c f = do
    old <- readMatrix x r c
    unsafeWriteMatrix x r c (f old)
-- | Apply a pure function to a snapshot of the matrix's current contents.
-- The function receives the copy produced by 'cloneMatrix' rather than the
-- mutable buffer itself.
liftSTMatrix :: (Element t) => (Matrix t -> a) -> STMatrix s t -> ST s a
liftSTMatrix f (STMatrix x) = unsafeIOToST (f <$> cloneMatrix x)

-- | Freeze without copying. The returned 'Matrix' shares the mutable
-- buffer, so the 'STMatrix' must not be written afterwards.
unsafeFreezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t)
unsafeFreezeMatrix (STMatrix x) = unsafeIOToST (pure x)

-- | Safe freeze: return the snapshot produced by 'liftSTMatrix'.
freezeMatrix :: (Element t) => STMatrix s t -> ST s (Matrix t)
freezeMatrix = liftSTMatrix id
-- | Copy a matrix with 'copy', passing the source's own storage order
-- ('orderOf'). 'thawMatrix' and 'liftSTMatrix' rely on the result not
-- sharing the input's buffer; presumably 'copy' guarantees this (confirm).
cloneMatrix :: Element t => Matrix t -> IO (Matrix t)
cloneMatrix m = copy (orderOf m) m
{-# INLINE safeIndexM #-}
-- | Bounds-checking wrapper for matrix accessors.
--
-- Calls @f m r c@ only when @0 <= r < rows m@ and @0 <= c < cols m@.
-- Otherwise it raises an 'error' reporting the matrix size and the
-- offending position. The result type @b@ is left abstract so that both
-- readers and writers can be wrapped.
safeIndexM :: (STMatrix s t -> Int -> Int -> b) -> STMatrix s t -> Int -> Int -> b
safeIndexM f (STMatrix m) r c
    | r < 0 || r >= rows m || c < 0 || c >= cols m =
        error $ "out of range error in matrix (size="
             ++ show (rows m, cols m) ++ ", pos=" ++ show (r, c) ++ ")"
    | otherwise = f (STMatrix m) r c

{-# INLINE readMatrix #-}
-- | Bounds-checked read of element (@r@, @c@); raises an 'error' when out of range.
readMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t
readMatrix = safeIndexM unsafeReadMatrix

{-# INLINE writeMatrix #-}
-- | Bounds-checked write of element (@r@, @c@); raises an 'error' when out of range.
writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s ()
writeMatrix = safeIndexM unsafeWriteMatrix
-- | Write the immutable matrix @m@ into the mutable one at offset
-- (@i@, @j@), delegating to 'setRect'. Presumably @m@'s top-left element
-- lands at row @i@, column @j@; confirm against 'setRect'.
setMatrix :: Element t => STMatrix s t -> Int -> Int -> Matrix t -> ST s ()
setMatrix (STMatrix dst) i j m = unsafeIOToST (setRect i j m dst)
-- | Allocate an @r@-by-@c@ mutable matrix with the given storage order via
-- 'createMatrix'. This function does not initialise the elements.
newUndefinedMatrix :: Storable t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t)
newUndefinedMatrix ord r c = unsafeIOToST (STMatrix <$> createMatrix ord r c)
{-# NOINLINE newMatrix #-}
-- | Allocate a mutable matrix of @r * c@ elements, all set to @v@, shaped
-- with @reshape c@.
--
-- The buffer is allocated /inside/ the returned action (via 'newVector').
-- Each execution of the action therefore yields a distinct matrix, even
-- when the same action value is run repeatedly. Examples are
-- @replicateM 2 (newMatrix 0 2 2)@, or a constant @newMatrix@ call that
-- GHC floats out of a loop. The previous version built the buffer in a pure
-- 'runSTVector' thunk captured by the action, so such repeated runs
-- returned aliased matrices.
--
-- NOTE(review): the final shape comes from 'reshape'; confirm it yields an
-- @r x 0@ matrix when @c == 0@.
newMatrix :: Storable t => t -> Int -> Int -> ST s (STMatrix s t)
newMatrix v r c = do
    STVector buf <- newVector v (r * c)
    unsafeThawMatrix (reshape c buf)
-- | A selection of matrix columns, resolved to an inclusive index pair by
-- 'getColRange'. Indices are reduced modulo the column count, so negative
-- values count from the end (e.g. @Col (-1)@ is the last column).
data ColRange = AllCols           -- ^ every column
              | ColRange Int Int  -- ^ from the first column to the second, inclusive
              | Col Int           -- ^ a single column
              | FromCol Int       -- ^ from the given column to the last one

-- | Resolve a 'ColRange' against a matrix with @c@ columns, returning the
-- inclusive @(first, last)@ column indices.
--
-- 'AllCols' on a matrix with no columns gives @(0, -1)@. The other
-- constructors use @`mod` c@, so their components raise a divide-by-zero
-- error when forced with @c == 0@.
getColRange :: Int -> ColRange -> (Int, Int)
getColRange c AllCols        = (0, c - 1)
getColRange c (ColRange a b) = (a `mod` c, b `mod` c)
getColRange c (Col a)        = (a `mod` c, a `mod` c)
getColRange c (FromCol a)    = (a `mod` c, c - 1)
-- | A selection of matrix rows, resolved to an inclusive index pair by
-- 'getRowRange'. Indices are reduced modulo the row count, so negative
-- values count from the end (e.g. @Row (-1)@ is the last row).
data RowRange = AllRows           -- ^ every row
              | RowRange Int Int  -- ^ from the first row to the second, inclusive
              | Row Int           -- ^ a single row
              | FromRow Int       -- ^ from the given row to the last one

-- | Resolve a 'RowRange' against a matrix with @r@ rows, returning the
-- inclusive @(first, last)@ row indices.
--
-- 'AllRows' on a matrix with no rows gives @(0, -1)@. The other
-- constructors use @`mod` r@, so their components raise a divide-by-zero
-- error when forced with @r == 0@.
getRowRange :: Int -> RowRange -> (Int, Int)
getRowRange r AllRows        = (0, r - 1)
getRowRange r (RowRange a b) = (a `mod` r, b `mod` r)
getRowRange r (Row a)        = (a `mod` r, a `mod` r)
getRowRange r (FromRow a)    = (a `mod` r, r - 1)
-- | An elementary row operation, applied in place by 'rowOper'.
--
-- Single row indices are reduced modulo the row count, and column or row
-- ranges are resolved by 'getColRange' / 'getRowRange'.
data RowOper t
    = AXPY t Int Int ColRange   -- ^ scalar, two rows, columns (rowOp code 0); presumably row2 += scalar * row1
    | SCAL t RowRange ColRange  -- ^ scalar, rows, columns (rowOp code 1); presumably scales the block
    | SWAP Int Int ColRange     -- ^ two rows, columns (rowOp code 2); presumably exchanges the rows

-- | Apply a 'RowOper' in place to a mutable matrix.
--
-- Each operation is forwarded to 'rowOp' with an operation code
-- (0 = 'AXPY', 1 = 'SCAL', 2 = 'SWAP') and resolved row/column bounds.
-- 'SWAP' passes @0@ as its (unused) scalar.
rowOper :: (Num t, Element t) => RowOper t -> STMatrix s t -> ST s ()
rowOper op (STMatrix m) = unsafeIOToST (run op)
  where
    nr = rows m
    nc = cols m
    wrapRow i = i `mod` nr
    run (AXPY x i1 i2 rc) =
        let (j1, j2) = getColRange nc rc
        in  rowOp 0 x (wrapRow i1) (wrapRow i2) j1 j2 m
    run (SCAL x rr rc) =
        let (i1, i2) = getRowRange nr rr
            (j1, j2) = getColRange nc rc
        in  rowOp 1 x i1 i2 j1 j2 m
    run (SWAP i1 i2 rc) =
        let (j1, j2) = getColRange nc rc
        in  rowOp 2 0 (wrapRow i1) (wrapRow i2) j1 j2 m
-- | Extract a rectangular region of a mutable matrix as an immutable one.
--
-- The row and column selections are resolved with 'getRowRange' and
-- 'getColRange'. Each is passed to 'extractR' as a two-element
-- @[first, last]@ index vector with mode @0@, together with the source's
-- storage order. NOTE(review): mode 0 presumably means "inclusive range";
-- confirm against 'extractR'.
extractMatrix (STMatrix src) rowSel colSel =
    unsafeIOToST (extractR (orderOf src) src 0 rowIdx 0 colIdx)
  where
    rowIdx = idxs [r0, r1]
    colIdx = idxs [c0, c1]
    (r0, r1) = getRowRange (rows src) rowSel
    (c0, c1) = getColRange (cols src) colSel
-- | A rectangular window into a mutable matrix, consumed by 'gemmm'.
-- @Slice m r0 c0 nr nc@ selects the block with origin (@r0@, @c0@) and
-- extent (@nr@, @nc@), as passed to 'subMatrix'.
data Slice s t = Slice (STMatrix s t) Int Int Int Int

-- | View a 'Slice' as a 'Matrix' by calling 'subMatrix' on the underlying
-- buffer with the slice's origin and extent.
slice (Slice (STMatrix mat) row0 col0 nRows nCols) = subMatrix origin extent mat
  where
    origin = (row0, col0)
    extent = (nRows, nCols)
-- | Matrix multiply-accumulate on slices of mutable matrices.
--
-- Calls 'gemm' with the scalars packed as @[alpha, beta]@, the two operand
-- slices, and the slice given after @beta@ as the last argument.
-- Presumably this computes @dst <- alpha * lhs * rhs + beta * dst@ in place;
-- confirm against 'gemm'.
gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s ()
gemmm beta dst alpha lhs rhs =
    unsafeIOToST $ gemm scalars (slice lhs) (slice rhs) (slice dst)
  where
    scalars = fromList [alpha, beta]
-- | Run an in-place algorithm on a private copy of a matrix.
--
-- The input is copied with 'thawMatrix'. The callback receives the
-- matrix's @(rows, cols)@ and the mutable copy. The mutated copy is then
-- frozen without a further copy and returned with the callback's result.
-- The callback only ever sees the copy, not the input matrix.
mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u)
mutable f a = runST $ do
    work <- thawMatrix a
    extra <- f (rows a, cols a) work
    frozen <- unsafeFreezeMatrix work
    pure (frozen, extra)