{-# LANGUAGE PatternGuards #-}
module Statistics.Matrix
(
Matrix(..)
, Vector
, fromVector
, fromList
, fromRowLists
, fromRows
, fromColumns
, toVector
, toList
, toRows
, toColumns
, toRowLists
, generate
, generateSym
, ident
, diag
, dimension
, center
, multiply
, multiplyV
, transpose
, power
, norm
, column
, row
, map
, for
, unsafeIndex
, hasNaN
, bounds
, unsafeBounds
) where
import Prelude hiding (exponent, map)
import Control.Applicative ((<$>))
import Control.Monad.ST
import qualified Data.Vector.Unboxed as U
import Data.Vector.Unboxed ((!))
import qualified Data.Vector.Unboxed.Mutable as UM
import Numeric.Sum (sumVector,kbn)
import Statistics.Matrix.Function
import Statistics.Matrix.Types
import Statistics.Matrix.Mutable (unsafeNew,unsafeWrite,unsafeFreeze)
-- | Build an @r × c@ matrix from a flat list of elements laid out row by
-- row. The size check (exactly @r*c@ elements) is delegated to
-- 'fromVector'.
fromList :: Int      -- ^ number of rows
         -> Int      -- ^ number of columns
         -> [Double] -- ^ elements in row-major order
         -> Matrix
fromList nRows nCols xs = fromVector nRows nCols (U.fromList xs)
-- | Build a matrix from a list of rows, each given as a plain list.
-- Validation (non-empty, equal non-zero row lengths) is done by 'fromRows'.
fromRowLists :: [[Double]] -> Matrix
fromRowLists rowLists = fromRows (U.fromList <$> rowLists)
-- | Wrap a flat vector, stored row by row, as an @r × c@ matrix without
-- copying it.
--
-- Raises an error when either dimension is negative, or when the vector
-- length is not exactly @r*c@.
fromVector :: Int             -- ^ number of rows
           -> Int             -- ^ number of columns
           -> U.Vector Double -- ^ elements in row-major order
           -> Matrix
fromVector r c v
  -- Two negative dimensions would otherwise pass the size check below,
  -- because their product is positive.
  | r < 0 || c < 0 = error "Statistics.Matrix.fromVector: negative dimension"
  -- Compare in Integer: an Int product may wrap around and spuriously
  -- equal the vector length for huge dimensions.
  | toInteger r * toInteger c /= toInteger len = error "input size mismatch"
  | otherwise = Matrix r c v
  where
    len = U.length v
-- | Build a matrix from a list of row vectors, concatenated in order.
--
-- Raises an error if the list is empty, if the rows do not all have the
-- same length, or if that common length is zero.
fromRows :: [Vector] -> Matrix
fromRows xs = case U.length <$> xs of
  [] -> error "Statistics.Matrix.fromRows: empty list of rows!"
  nCol : otherLens
    | any (/= nCol) otherLens ->
        error "Statistics.Matrix.fromRows: row sizes do not match"
    | nCol == 0 ->
        error "Statistics.Matrix.fromRows: zero columns in matrix"
    | otherwise ->
        fromVector (length xs) nCol (U.concat xs)
-- | Build a matrix from a list of column vectors. The columns are first
-- assembled as rows with 'fromRows', so the same validation applies, and
-- the result is then transposed.
fromColumns :: [Vector] -> Matrix
fromColumns columnVecs = transpose (fromRows columnVecs)
-- | The underlying flat element vector, in row-major order.
toVector :: Matrix -> U.Vector Double
toVector (Matrix _ _ elems) = elems
-- | All elements as one flat list, in row-major order.
toList :: Matrix -> [Double]
toList m = U.toList (toVector m)
toRowLists :: Matrix -> [[Double]]
toRowLists (Matrix _ nCol v)
= chunks $ U.toList v
where
chunks [] = []
chunks xs = case splitAt nCol xs of
(rowE,rest) -> rowE : chunks rest
-- | Split the matrix into its rows, each returned as a vector.
toRows :: Matrix -> [Vector]
toRows (Matrix _ nCol v) = splitRows v
  where
    -- Peel off one row of nCol elements at a time until the vector is empty.
    splitRows remaining
      | U.null remaining = []
      | otherwise        = case U.splitAt nCol remaining of
          (current, rest) -> current : splitRows rest
-- | Split the matrix into its columns, each returned as a vector
-- (the rows of the transpose).
toColumns :: Matrix -> [Vector]
toColumns m = toRows (transpose m)
-- | Build an @nRow × nCol@ matrix whose element at row @r@, column @c@ is
-- @f r c@. Elements are stored in row-major order.
generate :: Int                    -- ^ number of rows
         -> Int                    -- ^ number of columns
         -> (Int -> Int -> Double) -- ^ element function, given row and column
         -> Matrix
generate nRow nCol f = Matrix nRow nCol (U.generate (nRow * nCol) cell)
  where
    -- Recover (row, column) from the flat row-major offset.
    cell i = uncurry f (i `quotRem` nCol)
-- | Build an @n × n@ symmetric matrix.
--
-- The element function is evaluated only on the diagonal and the upper
-- triangle (row <= column). Each upper-triangle value is also written to
-- the mirrored lower-triangle position, so every cell is initialised.
generateSym
  :: Int                    -- ^ dimension @n@
  -> (Int -> Int -> Double) -- ^ element function, given row and column
  -> Matrix
generateSym n f = runST $ do
  mat <- unsafeNew n n
  -- Write f i j at (i, j) and at its mirror image (j, i).
  let mirror i j = do
        let x = f i j
        unsafeWrite mat i j x
        unsafeWrite mat j i x
  for 0 n $ \i -> do
    unsafeWrite mat i i (f i i)
    for (i + 1) n (mirror i)
  unsafeFreeze mat
-- | The @n × n@ identity matrix.
ident :: Int -> Matrix
ident n = diag (U.replicate n 1)
-- | Square matrix with the given vector on its main diagonal and zeros
-- everywhere else. Its size is the length of the vector.
diag :: Vector -> Matrix
diag v = Matrix n n $ U.create $ do
  arr <- UM.replicate (n * n) 0
  -- In the row-major layout, diagonal cell (i, i) sits at offset i*(n+1).
  for 0 n $ \i -> UM.unsafeWrite arr (i * (n + 1)) (v ! i)
  return arr
  where
    n = U.length v
-- | The matrix dimensions as @(rows, columns)@.
dimension :: Matrix -> (Int, Int)
dimension (Matrix nRows nCols _) = (nRows, nCols)
-- | Matrix product @m1 × m2@.
--
-- Element @(i, j)@ is the dot product of row @i@ of @m1@ and column @j@ of
-- @m2@, summed with 'sumVector' using the 'kbn' compensated summation.
--
-- Raises an error when the number of columns of @m1@ differs from the
-- number of rows of @m2@. This mirrors the check in 'multiplyV'; without
-- it, 'U.zipWith' would silently truncate to the shorter operand and
-- produce a meaningless result.
multiply :: Matrix -> Matrix -> Matrix
multiply m1@(Matrix r1 c1 _) m2@(Matrix r2 c2 _)
  | c1 /= r2  = error $ "matrix/matrix unconformable " ++ show (c1, r2)
  | otherwise = Matrix r1 c2 $ U.generate (r1 * c2) go
  where
    -- Flat row-major offset t of the result maps to (row i, column j).
    go t = sumVector kbn $ U.zipWith (*) (row m1 i) (column m2 j)
      where (i, j) = t `quotRem` c2
-- | Matrix–vector product @m × v@.
--
-- Element @i@ of the result is the dot product of @v@ with row @i@ of @m@,
-- summed with 'sumVector' using 'kbn'. Raises an error if the length of
-- @v@ differs from the number of columns of @m@.
multiplyV :: Matrix -> Vector -> Vector
multiplyV m v
  | cols m /= c = error $ "matrix/vector unconformable " ++ show (cols m,c)
  | otherwise   = U.generate (rows m) dot
  where
    c     = U.length v
    dot i = sumVector kbn (U.zipWith (*) v (row m i))
-- | Raise a matrix to a non-negative integer power by repeated squaring,
-- which needs O(log n) calls to 'multiply'.
--
-- @power m 0@ is the identity matrix with as many rows as @m@. A negative
-- exponent raises an error. The recursion below only terminates for
-- exponents >= 1, so smaller exponents must be handled up front.
power :: Matrix -> Int -> Matrix
power mat n
  | n < 0     = error $ "Statistics.Matrix.power: negative exponent " ++ show n
  | n == 0    = ident (rows mat)
  | otherwise = go n
  where
    -- Invariant: k >= 1, so k `quot` 2 >= 1 whenever k >= 2.
    go 1 = mat
    go k
      | odd k     = multiply squared mat
      | otherwise = squared
      where
        half    = go (k `quot` 2)
        squared = multiply half half
-- | The element at row @r \`quot\` 2@ and column @c \`quot\` 2@ of an
-- @r × c@ matrix.
--
-- No bounds checking is performed, so the matrix must be non-empty.
center :: Matrix -> Double
center mat@(Matrix r c _) = unsafeBounds U.unsafeIndex mat midRow midCol
  where
    midRow = r `quot` 2
    midCol = c `quot` 2
-- | Square root of the 'kbn'-compensated sum of 'square' applied to each
-- element. This is presumably the Euclidean (L2) norm, assuming 'square'
-- is @x*x@.
norm :: Vector -> Double
norm v = sqrt (sumVector kbn (U.map square v))
-- | Column @j@ of the matrix as a new vector. It gathers the elements at
-- flat offsets @j, j+c, j+2c, …@, one per row.
column :: Matrix -> Int -> Vector
column (Matrix nRows nCols v) j =
  U.backpermute v (U.enumFromStepN j nCols nRows)
{-# INLINE column #-}
-- | Row @i@ of the matrix: the @c@ elements starting at flat offset @c*i@,
-- taken with 'U.slice'.
row :: Matrix -> Int -> Vector
row (Matrix _ nCols v) i = U.slice (nCols * i) nCols v
-- | Element at the given row and column, without any bounds checking.
unsafeIndex :: Matrix
            -> Int    -- ^ row
            -> Int    -- ^ column
            -> Double
unsafeIndex m r c = unsafeBounds U.unsafeIndex m r c
-- | Apply a function to every element, keeping the dimensions unchanged.
map :: (Double -> Double) -> Matrix -> Matrix
map f (Matrix nRows nCols elems) = Matrix nRows nCols $ U.map f elems
-- | 'True' if at least one element of the matrix is NaN.
hasNaN :: Matrix -> Bool
hasNaN (Matrix _ _ elems) = U.any isNaN elems
-- | Bounds-checked element access in continuation style.
--
-- The row and column are checked against the matrix dimensions, raising an
-- error if either is out of range. Otherwise @k@ receives the underlying
-- vector and the (strictly evaluated) row-major flat offset.
bounds :: (Vector -> Int -> r) -> Matrix -> Int -> Int -> r
bounds k (Matrix rs cs v) r c
  | not (0 <= r && r < rs) = error "row out of bounds"
  | not (0 <= c && c < cs) = error "column out of bounds"
  | otherwise              = let off = r * cs + c in off `seq` k v off
{-# INLINE bounds #-}
-- | Unchecked counterpart of 'bounds'. It passes the underlying vector and
-- the (strictly evaluated) row-major flat offset @r*cs + c@ to @k@,
-- without validating the row or column.
unsafeBounds :: (Vector -> Int -> r) -> Matrix -> Int -> Int -> r
unsafeBounds k (Matrix _ cs v) r c = off `seq` k v off
  where
    off = r * cs + c
{-# INLINE unsafeBounds #-}
-- | Transpose an @r × c@ matrix into a @c × r@ matrix. Element @(i, j)@ of
-- the result is element @(j, i)@ of the input.
transpose :: Matrix -> Matrix
transpose m@(Matrix r0 c0 _) = Matrix c0 r0 (U.generate (r0 * c0) pick)
  where
    -- The result has r0 columns, so offset k maps to (row i, column j)
    -- of the result via quotRem by r0.
    pick k = let (i, j) = k `quotRem` r0 in unsafeIndex m j i