{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}

module Internal.Sparse(
    GMatrix(..), CSR(..), mkCSR, fromCSR, impureCSR,
    mkSparse, mkDiagR, mkDense,
    AssocMatrix,
    toDense,
    gmXv, (!#>)
)where

import Internal.Vector
import Internal.Matrix
import Internal.Numeric
import qualified Data.Vector.Storable as V
import qualified Data.Vector.Storable.Mutable as M
import Control.Arrow((***))
import Control.Monad(when, foldM)
import Control.Monad.ST (runST)
import Control.Monad.Primitive (PrimMonad)
import Data.List(sort)
import Foreign.C.Types(CInt(..))

import Internal.Devel
import System.IO.Unsafe(unsafePerformIO)
import Foreign(Ptr)
import Text.Printf(printf)

-- | Association-list form of a sparse matrix: every element pairs a
--   @(row, column)@ index (i.e. @IndexOf Matrix@) with the stored value.
type AssocMatrix = [((Int, Int), Double)]

-- | Compressed sparse row storage.
--
--   When built by 'mkCSR' / 'impureCSR', column indices and row pointers
--   are stored 1-based (0-based index plus one).
data CSR = CSR
    { csrVals  :: Vector Double  -- ^ non-zero values, row by row
    , csrCols  :: Vector CInt    -- ^ column index of each value
    , csrRows  :: Vector CInt    -- ^ row pointers (one per row, plus a closing entry)
    , csrNRows :: Int            -- ^ number of rows
    , csrNCols :: Int            -- ^ number of columns
    }
    deriving (Show)

-- | Compressed sparse column storage; produced by transposing a 'CSR',
--   which reuses the same arrays with the roles of rows and columns swapped.
data CSC = CSC
    { cscVals  :: Vector Double  -- ^ non-zero values, column by column
    , cscRows  :: Vector CInt    -- ^ row index of each value
    , cscCols  :: Vector CInt    -- ^ column pointers
    , cscNRows :: Int            -- ^ number of rows
    , cscNCols :: Int            -- ^ number of columns
    }
    deriving (Show)


-- | Produce a CSR sparse matrix from an association matrix.
--
--   The entries are sorted (by row, then column) before being streamed
--   through 'impureCSR', so the input may be given in any order.
mkCSR :: AssocMatrix -> CSR
mkCSR ms = runST (impureCSR foldAll (sort ms))
  where
    -- Run impureCSR's (step, start, finish) triple over a pure list.
    foldAll step start finish xs =
      start >>= \s0 -> foldM step s0 xs >>= finish

-- | Produce a CSR sparse matrix by applying a generic folding function.
--
--   This allows one to build a CSR from an effectful streaming source
--   when combined with libraries like pipes, io-streams, or streaming.
--
--   For example
--
--   > impureCSR Pipes.Prelude.foldM :: PrimMonad m => Producer AssocEntry m () -> m CSR
--   > impureCSR Streaming.Prelude.foldM :: PrimMonad m => Stream (Of AssocEntry) m r -> m (Of CSR r)
--
--   (@AssocEntry@ above stands for one element of an 'AssocMatrix',
--   i.e. @(IndexOf Matrix, Double)@.)
--
--   The folding function is handed a step function, an initial action and
--   an extraction action. Entries must arrive in non-decreasing row order;
--   a row smaller than one already seen raises an error. Within a row,
--   entries are stored in arrival order. Column indices and row pointers
--   are stored 1-based (0-based index plus one). 'csrNRows' is one more
--   than the last row seen; 'csrNCols' is one more than the largest column
--   seen (and at least 1).
impureCSR
    :: PrimMonad m
    => (forall x . (x -> (IndexOf Matrix, Double) -> m x) -> m x -> (x -> m CSR) -> r)
    -> r
impureCSR f = f next begin done
  where
    -- 0-based index -> 1-based CInt, as stored in the CSR arrays.
    sfi = succ . fi

    -- Fold state: (values, row pointers, column indices, next value slot,
    -- next row-pointer slot, largest column so far, current row).
    -- Every buffer starts with 64 slots; the row is -1 before any entry.
    begin = do
      mv <- M.unsafeNew 64
      mr <- M.unsafeNew 64
      mc <- M.unsafeNew 64
      return (mv, mr, mc, 0, 0, 0, -1)

    next (!mv, !mr, !mc, !idxVC, !idxR, !maxC, !curRow) ((r,c),d) = do
      when (r < curRow) $
        error (printf "impureCSR: row %i specified after %i" r curRow)

      let lenVC = M.length mv
          lenR  = M.length mr
          maxC' = max maxC c
          -- One row pointer is written for every row started since curRow.
          idxR' = idxR + (r - curRow)
          -- Keep one extra slot for the closing pointer written by 'done'.
          needR = idxR' + 1

      (mv', mc') <-
        if idxVC >= lenVC then do
          mv' <- M.unsafeGrow mv lenVC
          mc' <- M.unsafeGrow mc lenVC
          return (mv', mc')
        else
          return (mv, mc)

      -- Skipping many empty rows can need more than a single doubling, so
      -- grow by the larger of the current size and the shortfall.  This
      -- keeps every unsafeWrite below within bounds.
      mr' <-
        if needR > lenR then
          M.unsafeGrow mr (max lenR (needR - lenR))
        else
          return mr

      M.unsafeWrite mc' idxVC (sfi c)
      M.unsafeWrite mv' idxVC d

      -- Every row from curRow+1 up to r begins at the current value slot.
      mapM_ (\i -> M.unsafeWrite mr' i (sfi idxVC)) [idxR .. idxR' - 1]

      return (mv', mr', mc', idxVC + 1, idxR', maxC', r)

    done (!mv, !mr, !mc, !idxVC, !idxR, !maxC, !curR) = do
      -- Closing row pointer: one past the last stored value.
      M.unsafeWrite mr idxR (sfi idxVC)
      vv <- V.unsafeFreeze (M.unsafeTake idxVC mv)
      vc <- V.unsafeFreeze (M.unsafeTake idxVC mc)
      vr <- V.unsafeFreeze (M.unsafeTake (idxR + 1) mr)
      return $ CSR vv vc vr (succ curR) (succ maxC)


{- | General matrix with specialized internal representations.

Every constructor carries its payload together with the logical
dimensions of the matrix ('nRows' x 'nCols'):

* 'SparseR': compressed sparse row storage ('CSR'), built by 'mkSparse'
  or 'fromCSR';

* 'SparseC': compressed sparse column storage, obtained by transposing
  a 'SparseR';

* 'Diag': main-diagonal entries, possibly fewer than @min nRows nCols@
  (the remaining diagonal is zero), built by 'mkDiagR';

* 'Dense': an ordinary dense 'Matrix', built by 'mkDense'.

A banded representation is planned but not implemented yet.

>>> let m = mkSparse [((0,999),1.0),((1,1999),2.0)]
>>> m
SparseR {gmCSR = CSR {csrVals = fromList [1.0,2.0],
                      csrCols = fromList [1000,2000],
                      csrRows = fromList [1,2,3],
                      csrNRows = 2,
                      csrNCols = 2000},
                      nRows = 2,
                      nCols = 2000}

>>> let m = mkDense (mat 2 [1..4])
>>> m
Dense {gmDense = (2><2)
 [ 1.0, 2.0
 , 3.0, 4.0 ], nRows = 2, nCols = 2}

-}
data GMatrix
    = SparseR
        { gmCSR    :: CSR            -- ^ row-compressed payload
        , nRows    :: Int            -- ^ number of rows
        , nCols    :: Int            -- ^ number of columns
        }
    | SparseC
        { gmCSC    :: CSC            -- ^ column-compressed payload
        , nRows    :: Int
        , nCols    :: Int
        }
    | Diag
        { diagVals :: Vector Double  -- ^ diagonal entries, starting at (0,0)
        , nRows    :: Int
        , nCols    :: Int
        }
    | Dense
        { gmDense  :: Matrix Double  -- ^ dense payload
        , nRows    :: Int
        , nCols    :: Int
        }
--    | Banded   (not implemented yet)
    deriving (Show)


-- | Wrap a dense matrix as a 'GMatrix', recording its dimensions.
mkDense :: Matrix Double -> GMatrix
mkDense m = Dense { gmDense = m, nRows = rows m, nCols = cols m }

-- | Build a row-compressed sparse 'GMatrix' from an association matrix.
mkSparse :: AssocMatrix -> GMatrix
mkSparse assocs = fromCSR (mkCSR assocs)

-- | Wrap an existing 'CSR' as a 'GMatrix', taking the dimensions from
--   the CSR itself.
fromCSR :: CSR -> GMatrix
fromCSR csr = SparseR
    { gmCSR = csr
    , nRows = csrNRows csr
    , nCols = csrNCols csr
    }


-- | @mkDiagR r c v@ builds an @r x c@ diagonal 'GMatrix' whose leading
--   diagonal entries are @v@. Fails if @v@ is longer than @min r c@.
mkDiagR :: Int -> Int -> Vector Double -> GMatrix
mkDiagR r c v
    | dim v > min r c = error $ printf "mkDiagR: incorrect sizes (%d,%d) [%d]" r c (dim v)
    | otherwise       = Diag { diagVals = v, nRows = r, nCols = c }


-- | A 'Double' array argument for C: a 'CInt' (presumably the element
--   count) followed by a pointer to the data.
type V t = CInt -> Ptr Double -> t

-- | A 'CInt' array argument for C, in the same (count, pointer) shape.
type IV t = CInt -> Ptr CInt -> t

-- | Signature shared by the C kernels @smXv@ and @smTXv@: a double array,
--   two int arrays, then two double arrays, returning a 'CInt' result
--   that callers check as an error code.
type SMxV = V (IV (IV (V (V (IO CInt)))))

-- | Multiply a general matrix by a vector.
--
--   The vector length must equal the number of columns; otherwise an error
--   reporting the offending sizes is raised. Sparse representations are
--   delegated to the C kernels @smXv@ / @smTXv@.
gmXv :: GMatrix -> Vector Double -> Vector Double
gmXv SparseR { gmCSR = CSR vals colIdx rowPtr _ _, nRows = n, nCols = m } x =
  unsafePerformIO $ do
    when (dim x /= m) $
      error (printf "gmXv (CSR): incorrect sizes: (%d,%d) x %d" n m (dim x))
    y <- createVector n
    (vals # colIdx # rowPtr # x #! y) c_smXv #| "CSRXv"
    return y

gmXv SparseC { gmCSC = CSC vals rowIdx colPtr _ _, nRows = n, nCols = m } x =
  unsafePerformIO $ do
    when (dim x /= m) $
      error (printf "gmXv (CSC): incorrect sizes: (%d,%d) x %d" n m (dim x))
    y <- createVector n
    (vals # rowIdx # colPtr # x #! y) c_smTXv #| "CSCXv"
    return y

gmXv Diag { diagVals = d, nRows = n, nCols = m } x
    | dim x /= m = error $ printf "gmXv (Diag): incorrect sizes: (%d,%d) [%d] x %d"
                                  n m (dim d) (dim x)
    -- Scale the leading entries of x, then pad with zeros up to n rows.
    | otherwise  = vjoin [ subVector 0 k x `mul` d, konst 0 (n - k) ]
  where
    k = dim d

gmXv Dense { gmDense = a, nRows = n, nCols = m } x
    | dim x /= m = error $ printf "gmXv (Dense): incorrect sizes: (%d,%d) x %d"
                                  n m (dim x)
    | otherwise  = mXv a x


{- | General matrix-vector product (an infix synonym for 'gmXv').

>>> let m = mkSparse [((0,999),1.0),((1,1999),2.0)]
>>> m !#> vector [1..2000]
[1000.0,4000.0]

-}
infixr 8 !#>
(!#>) :: GMatrix -> Vector Double -> Vector Double
m !#> v = gmXv m v

--------------------------------------------------------------------------------

-- C kernel used by the SparseR case of gmXv (error tag "CSRXv").
foreign import ccall unsafe "smXv" c_smXv :: SMxV

-- C kernel used by the SparseC case of gmXv (error tag "CSCXv").
foreign import ccall unsafe "smTXv" c_smTXv :: SMxV

--------------------------------------------------------------------------------

-- | Build a dense matrix from an association matrix.
--
--   The result has one more row and one more column than the largest row
--   and column indices present; positions not listed are filled with 0.
--   An empty list yields a 0x0 matrix (previously 'maximum' failed on it).
toDense :: AssocMatrix -> Matrix Double
toDense []  = assoc (0, 0) 0 []
toDense asm = assoc (r + 1, c + 1) 0 asm
  where
    (r, c) = (maximum *** maximum) . unzip . map fst $ asm


-- | Transposing a CSR reinterprets its arrays as CSC: the column indices
--   become row indices, the row pointers become column pointers, and the
--   dimensions swap. No data is copied.
instance Transposable CSR CSC
  where
    tr (CSR vals cols rowPtr nr nc) =
      CSC { cscVals  = vals
          , cscRows  = cols
          , cscCols  = rowPtr
          , cscNRows = nc
          , cscNCols = nr
          }
    tr' = tr

-- | Transposing a CSC reinterprets its arrays as CSR: the row indices
--   become column indices, the column pointers become row pointers, and
--   the dimensions swap. No data is copied.
instance Transposable CSC CSR
  where
    tr (CSC vals rowIdx colPtr nr nc) =
      CSR { csrVals  = vals
          , csrCols  = rowIdx
          , csrRows  = colPtr
          , csrNRows = nc
          , csrNCols = nr
          }
    tr' = tr

-- | Transpose a general matrix: sparse row and column forms swap into each
--   other, a diagonal keeps its values, a dense matrix is transposed; in
--   every case the recorded dimensions are exchanged.
instance Transposable GMatrix GMatrix
  where
    tr g = case g of
      SparseR csr rs cs -> SparseC (tr csr) cs rs
      SparseC csc rs cs -> SparseR (tr csc) cs rs
      Diag    dv  rs cs -> Diag dv cs rs
      Dense   a   rs cs -> Dense (tr a) cs rs
    tr' = tr