{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Data.Matrix.Static.Sparse
    ( -- * Sparse matrix
      SparseMatrix(..)
    , Zero(..)

      -- * Accessors
      -- ** length information
    , C.dim
    , C.rows
    , C.cols

      -- ** Query
    , (C.!)
    , C.takeDiag

      -- ** Unsafe Query
    , C.unsafeIndex
    , C.unsafeTakeRow
    , C.unsafeTakeColumn
    , unsafeTakeColumnC

      -- * Construction
    , C.empty
    , fromTriplet
    , fromTripletC
    , toTriplet
    , C.fromVector
    , C.fromList
    , C.unsafeFromVector
    , diag
    , diagRect

      -- * Conversions
    , toDense
    , C.flatten
    , C.toList

      -- * Different matrix types
    , C.convertAny
    ) where

import Control.Monad.ST (runST)
import Data.Bits (shiftR)
import Data.Complex
import Data.Tuple (swap)
import Foreign.C.Types
import Foreign.Storable (sizeOf)
import GHC.TypeLits (type (<=))
import Text.Printf (printf)

import Conduit
import Data.Conduit.Internal (zipSinks)
import Data.Singletons
import Data.Store (Store(..), Size(..))
import qualified Data.Vector as V
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as GM
import qualified Data.Vector.Storable as S
import qualified Data.Vector.Storable.Mutable as SM

import qualified Data.Matrix.Static.Dense as D
import qualified Data.Matrix.Static.Dense.Mutable as DM
import qualified Data.Matrix.Static.Generic as C
import Data.Matrix.Static.Sparse.Mutable

-- | The mutable counterpart of 'SparseMatrix'.
type instance C.Mutable SparseMatrix = MSparseMatrix

-- | Element types with a distinguished 'zero' value. Sparse matrices return
-- 'zero' for every position that has no stored entry, and use equality with
-- 'zero' to decide which entries to drop when built from a dense vector.
class Eq a => Zero a where
    -- | The value standing for "no entry".
    zero :: a

instance Zero Int where
    zero = 0

instance Zero Float where
    zero = 0.0

instance Zero CFloat where
    zero = 0.0

instance Zero Double where
    zero = 0.0

instance Zero (Complex Float) where
    zero = 0

instance Zero (Complex Double) where
    zero = 0

instance Eq a => Zero ([] a) where
    zero = []

-- | Column-major sparse matrix. This type is immutable; its mutable
-- counterpart is 'MSparseMatrix'.
--
-- For column @j@ the stored entries occupy positions
-- @outer[j] .. outer[j+1] - 1@ of the value and inner-index vectors.
data SparseMatrix :: C.MatrixKind where
    SparseMatrix :: (SingI r, SingI c)
                 => (v a)             -- ^ Values: stores the coefficient values
                                      -- of the non-zeros.
                 -> (S.Vector CInt)   -- ^ InnerIndices: stores the row
                                      -- (resp. column) indices of the non-zeros.
                 -> (S.Vector CInt)   -- ^ OuterStarts: stores for each column
                                      -- (resp. row) the index of the first
                                      -- non-zero in the previous two arrays.
                 -> SparseMatrix r c v a

-- | Serialised layout: row count, column count (both as 'Int'), then the
-- value vector, the inner indices and the outer starts.
instance (G.Vector v a, Zero a, Store (v a), SingI r, SingI c) =>
    Store (SparseMatrix r c v a) where
    size = VarSize $ \(SparseMatrix nnz inner outer) ->
        2 * sizeOf (0 :: Int)
            + sizeWith size nnz
            + sizeWith size inner
            + sizeWith size outer
      where
        -- Total over both 'Size' constructors, so a component whose 'Store'
        -- instance reports a constant size no longer hits a bottom.
        sizeWith :: Size x -> x -> Int
        sizeWith (VarSize f) x = f x
        sizeWith (ConstSize n) _ = n
    poke mat@(SparseMatrix nnz inner outer) =
        poke r >> poke c >> poke nnz >> poke inner >> poke outer
      where
        (r,c) = C.dim mat
    -- Reads the stored dimensions first and refuses to decode a matrix whose
    -- dimensions differ from the ones demanded by the result type.
    peek = do
        r' <- peek
        c' <- peek
        if r' /= r || c' /= c
            then error $ "Dimensions donot match: " <> show (r,c) <> " /= " <> show (r',c')
            else SparseMatrix <$> peek <*> peek <*> peek
      where
        r = fromIntegral $ fromSing (sing :: Sing r) :: Int
        c = fromIntegral $ fromSing (sing :: Sing c) :: Int

-- | Structural equality: compares the three underlying vectors directly.
instance (G.Vector v a, Eq (v a)) => Eq (SparseMatrix r c v a) where
    (==) (SparseMatrix a b c) (SparseMatrix a' b' c') =
        a == a' && b == b' && c == c'

-- | Prints the dimensions followed by every row, one row per line.
instance (G.Vector v a, Zero a, Show a) => Show (SparseMatrix r c v a) where
    show mat = printf "(%d x %d)\n%s" r c vals
      where
        (r,c) = C.dim mat
        vals = unlines $ map (unwords . map show . G.toList) $ C.toRows mat

instance (G.Vector v a, Zero a) => C.Matrix SparseMatrix v a where
    -- | O(1) Return the size of matrix.
    dim :: forall r c. SparseMatrix r c v a -> (Int, Int)
    dim (SparseMatrix _ _ _) = (r,c)
      where
        r = fromIntegral $ fromSing (sing :: Sing r)
        c = fromIntegral $ fromSing (sing :: Sing c)
    {-# INLINE dim #-}

    -- | Unsafe indexing without bound check. Binary-searches the row indices
    -- stored for column @j@ and returns 'zero' when row @i@ is not among them.
    unsafeIndex (SparseMatrix vec inner outer) (i,j) =
        maybe zero (G.unsafeIndex vec) $
            binarySearchByBounds inner (fromIntegral i) r0 r1
      where
        r0 = fromIntegral $ outer `S.unsafeIndex` j
        r1 = fromIntegral $ outer `S.unsafeIndex` (j+1) - 1
    {-# INLINE unsafeIndex #-}

    -- | Materialise column @i@ as a dense vector, starting from all 'zero's
    -- and overwriting the positions that have a stored entry.
    unsafeTakeColumn mat i = G.create $ do
        vec <- GM.replicate (C.rows mat) zero
        let f (r,_,v) = GM.unsafeWrite vec r v
        runConduit $ unsafeTakeColumnC mat i .| mapM_C f
        return vec
    {-# INLINE unsafeTakeColumn #-}

    -- | O(n) Create matrix from vector containing columns. Elements equal to
    -- 'zero' are dropped; the rest are stored.
    unsafeFromVector :: forall r c. (G.Vector v a, SingI r, SingI c)
                     => v a -> SparseMatrix r c v a
    unsafeFromVector vec = fromTriplet vec'
      where
        vec' = V.fromList $ map (\((a,b),c) -> (a,b,c)) $
            filter ((/=zero) . snd) $
            zipWith (\i x -> (toIndex i, x)) [0..] $ G.toList vec
        -- Flat position -> (row, column) for column-major input.
        toIndex i = swap $ i `divMod` r
        r = fromIntegral $ fromSing (sing :: Sing r)
    {-# INLINE unsafeFromVector #-}

    -- | Transpose by streaming the triplets with row and column swapped.
    transpose mat = runIdentity $ fromTripletC source
      where
        source = toTriplet mat .| mapC (\(i,j,x) -> (j,i,x))
    {-# INLINE transpose #-}

    -- The mutable conversions are not implemented for sparse matrices.
    thaw = error "Data.Matrix.Static.Sparse.thaw: not implemented"
    {-# INLINE thaw #-}

    unsafeThaw = error "Data.Matrix.Static.Sparse.unsafeThaw: not implemented"
    {-# INLINE unsafeThaw #-}

    freeze = error "Data.Matrix.Static.Sparse.freeze: not implemented"
    {-# INLINE freeze #-}

    unsafeFreeze = error "Data.Matrix.Static.Sparse.unsafeFreeze: not implemented"
    {-# INLINE unsafeFreeze #-}

    -- | Map over the stored values only; the sparsity pattern is unchanged.
    map f (SparseMatrix vec inner outer) = SparseMatrix (G.map f vec) inner outer
    {-# INLINE map #-}

    -- | Like 'C.map', but the function also receives the @(row, column)@ of
    -- each stored value.
    imap f mat@(SparseMatrix _ inner outer) = SparseMatrix vec' inner outer
      where
        vec' = runST $ runConduit $ toTriplet mat .| mapC g .| sinkVector
        g (i,j,x) = f (i,j) x
    {-# INLINE imap #-}

    -- | Run an action on every stored value together with its position.
    imapM_ f mat@(SparseMatrix _ _ _) = runConduit $ toTriplet mat .| mapM_C g
      where
        g (i,j,x) = f (i,j) x >> return ()
    {-# INLINE imapM_ #-}

    sequence (SparseMatrix vec inner outer) = do
        vec' <- G.sequence vec
        return $ SparseMatrix vec' inner outer
    {-# INLINE sequence #-}

    sequence_ (SparseMatrix vec _ _) = G.sequence_ vec
    {-# INLINE sequence_ #-}

-- | Convert to a dense matrix: start from all 'zero's and write every stored
-- entry at its position.
toDense :: (Zero a, G.Vector v a, SingI r, SingI c)
        => SparseMatrix r c v a -> D.Matrix r c v a
toDense mat = D.create $ do
    m <- DM.replicate zero
    flip C.imapM_ mat $ \idx -> DM.unsafeWrite m idx
    return m
{-# INLINE toDense #-}

-- | Stream a column. Yields @(row, column, value)@ for every stored entry of
-- column @i@. The outer starts are read with bounds checking, the inner
-- indices and values without.
unsafeTakeColumnC :: (Monad m, G.Vector v a)
                  => SparseMatrix r c v a
                  -> Int
                  -> ConduitT i (Int, Int, a) m ()
unsafeTakeColumnC (SparseMatrix nnz inner outer) i = enumFromToC lo hi .| mapC f
  where
    f idx = (fromIntegral $ inner `S.unsafeIndex` idx, i, nnz `G.unsafeIndex` idx)
    lo = fromIntegral $ outer S.! i
    hi = fromIntegral $ outer S.! (i+1) - 1
{-# INLINE unsafeTakeColumnC #-}

-- | O(n) Create matrix from triplets @(row, column, value)@. Columns may
-- appear in any order. Duplicate entries are carried over to the compressed
-- column representation.
--
-- Entries of one column are stored in the order in which they appear in the
-- input. 'C.unsafeIndex' binary-searches the row indices of a column, so
-- lookups are only reliable when the rows of each column arrive in increasing
-- order.
--
-- Calls 'error' when an index is negative or not smaller than the matrix
-- dimension.
fromTriplet :: forall u r c v a.
               (G.Vector u (Int, Int, a), G.Vector v a, SingI r, SingI c)
            => u (Int, Int, a) -> SparseMatrix r c v a
fromTriplet triplets = SparseMatrix val inner outer
  where
    -- Count the entries per column, then prefix-sum into column starts.
    outer = S.scanl (+) 0 $ S.create $ do
        vec <- SM.replicate c 0
        G.forM_ triplets $ \(i, j, _) ->
            -- Both bounds are checked: everything below uses unsafe writes.
            if i >= 0 && i < r && j >= 0 && j < c
                then SM.unsafeModify vec (+1) j
                else error $ printf "Index out of bound: (%d, %d) >= (%d, %d)" i j r c
        return vec
    (val, inner) = runST $ do
        -- outer' holds, per column, the next free slot to write to.
        outer' <- S.thaw outer
        val' <- GM.new nnz
        inner' <- SM.new nnz
        G.forM_ triplets $ \(i, j, v) -> do
            idx <- fromIntegral <$> SM.unsafeRead outer' j
            GM.unsafeWrite val' idx v
            SM.unsafeWrite inner' idx $ fromIntegral i
            SM.unsafeModify outer' (+1) j
        (,) <$> G.unsafeFreeze val' <*> S.unsafeFreeze inner'
    nnz = G.length triplets
    r = fromIntegral $ fromSing (sing :: Sing r)
    c = fromIntegral $ fromSing (sing :: Sing c)
{-# INLINE fromTriplet #-}

-- | O(n) Create matrix from triplets. Row and column indices *are not*
-- assumed to be ordered. Duplicate entries are carried over to the
-- compressed column representation.
--
-- NOTE: The Conduit will be consumed twice. Use `fromTriplet` if generating
-- the Conduit is expensive.
--
-- The first pass calls 'error' when an index is negative or not smaller than
-- the matrix dimension.
fromTripletC :: forall m r c v a. (Monad m, G.Vector v a, SingI r, SingI c)
             => ConduitT () (Int, Int, a) m ()
             -> m (SparseMatrix r c v a)
fromTripletC triplets = do
    (nnz, outer) <- runConduit $ triplets .| zipSinks lengthC sinkOuter
    (val, inner, _) <- runConduit $ triplets .| sinkValInner nnz (clone outer)
    return $ SparseMatrix val inner outer
  where
    -- First pass: per-column counts, prefix-summed into column starts.
    sinkOuter = S.scanl (+) 0 <$> foldlC f (S.replicate c 0)
      where
        f vec (i, j, _)
            -- Indices are validated here because both passes write unchecked.
            | i >= 0 && i < r && j >= 0 && j < c =
                S.modify (\v -> SM.unsafeModify v (+1) j) vec
            | otherwise = error $
                printf "Index out of bound: (%d, %d) >= (%d, %d)" i j r c
    -- Second pass: place each value and row index at the next free slot of
    -- its column. The accumulator vectors are updated in place through
    -- 'unsafeThaw'.
    -- NOTE(review): the tuple components are lazy, so this appears to rely on
    -- evaluation order for reading @idx@ before @outer@ is bumped; confirm
    -- before restructuring.
    sinkValInner nnz outer0 = foldlC f (val0, inner0, outer0)
      where
        val0 = G.create $ GM.new nnz
        inner0 = S.create $ SM.new nnz
        f (val, inner, outer) (i, j, v) = (val', inner', outer')
          where
            idx = fromIntegral $ outer `S.unsafeIndex` j
            val' = G.create $ do
                vec <- G.unsafeThaw val
                GM.unsafeWrite vec idx v
                return vec
            inner' = S.create $ do
                vec <- S.unsafeThaw inner
                SM.unsafeWrite vec idx $ fromIntegral i
                return vec
            outer' = S.create $ do
                vec <- S.unsafeThaw outer
                SM.unsafeModify vec (+1) j
                return vec
    r = fromIntegral $ fromSing (sing :: Sing r) :: Int
    c = fromIntegral $ fromSing (sing :: Sing c)
    -- Fresh copy, so the in-place updates above do not touch the result.
    clone x = S.create $ S.thaw x
{-# INLINE fromTripletC #-}

-- | Convert sparse matrix to triplets in column order.
--
-- Every stored entry is emitted as @(row, column, value)@. The fold walks the
-- outer starts: each step emits the entries between the previous start and
-- the current one and tags them with the previous column.
toTriplet :: (Monad m, G.Vector v a, SingI r, SingI c)
          => SparseMatrix r c v a -> ConduitT i (Int, Int, a) m ()
toTriplet (SparseMatrix val inner outer) =
    G.ifoldM_ go (fromIntegral $ G.head outer) outer
  where
    -- @curC@ is the position in @outer@, i.e. one past the column whose
    -- entries end at @end@; the first step covers an empty range.
    go start curC end = do
        enumFromToC start (end'-1) .| mapC f
        return end'
      where
        end' = fromIntegral end
        f i = ( fromIntegral $ inner `G.unsafeIndex` i
              , fromIntegral curC - 1
              , val `G.unsafeIndex` i
              )
{-# INLINE toTriplet #-}

-- | O(n) Create a square matrix with the given diagonal; every off-diagonal
-- position has no stored entry.
diag :: (G.Vector v a, Zero a, SingI n)
     => D.Matrix n 1 v a    -- ^ diagonal
     -> SparseMatrix n n v a
diag = diagRect
{-# INLINE diag #-}

-- | O(n + c) Create a rectangular matrix with the given diagonal; every
-- off-diagonal position has no stored entry.
--
-- The outer-start vector always has @c + 1@ entries, so the columns to the
-- right of the diagonal are represented as empty columns.
diagRect :: forall v a r c n.
            (G.Vector v a, Zero a, SingI r, SingI c, n <= r, n <= c)
         => D.Matrix n 1 v a    -- ^ diagonal
         -> SparseMatrix r c v a
diagRect d = SparseMatrix (C.flatten d) (S.enumFromN 0 n) outer
  where
    n = C.rows d
    c = fromIntegral $ fromSing (sing :: Sing c) :: Int
    -- Column j < n starts at slot j; every column from n on starts (and
    -- ends) at slot n, i.e. is empty.
    outer = S.generate (c + 1) (fromIntegral . min n)
{-# INLINE diagRect #-}

-- | Binary search for @x@ in @vec@ between the inclusive positions @l@ and
-- @u@. Returns the position of a match, or 'Nothing' when the range is empty
-- or holds no match. The slice is expected to be in ascending order and
-- inside the vector; elements are read without bounds checking.
binarySearchByBounds :: S.Vector CInt -> CInt -> Int -> Int -> Maybe Int
binarySearchByBounds vec x = loop
  where
    loop !l !u
        | l > u = Nothing
        | x == x' = Just k
        | x < x' = loop l (k-1)
        | otherwise = loop (k+1) u
      where
        k = (u+l) `shiftR` 1
        x' = vec `S.unsafeIndex` k
{-# INLINE binarySearchByBounds #-}

-------------------------------------------------------------------------------
-- Helper
-------------------------------------------------------------------------------

--getIndex :: Int -> (Int, Int)
--getIndex =
--{-# INLINE getIndex #-}