{-# LANGUAGE PatternGuards #-}
module Statistics.Matrix
(
Matrix(..)
, Vector
, fromVector
, dimension
, multiplyV
, transpose
, norm
, column
, for
, unsafeIndex
) where
import Prelude hiding (exponent, map, sum)
import qualified Data.Vector.Unboxed as U
import Statistics.Function (for, square)
import Statistics.Matrix.Types
import Statistics.Sample.Internal (sum)
-- | Build an @r@ × @c@ matrix from a flat vector holding the elements in
-- row-major order (row @i@ occupies offsets @c*i .. c*i + c - 1@).
--
-- Calls 'error' if either dimension is negative, or if the vector's length
-- is not exactly @r * c@. The third constructor field (presumably the
-- exponent scale declared in "Statistics.Matrix.Types" — confirm there) is
-- set to 0.
fromVector :: Int               -- ^ Number of rows.
           -> Int               -- ^ Number of columns.
           -> U.Vector Double   -- ^ Elements in row-major order.
           -> Matrix
fromVector r c v
  | r < 0 || c < 0  = error "negative matrix dimensions"
  | not conformable = error "input size mismatch"
  | otherwise       = Matrix r c 0 v
  where
    len = U.length v
    -- Check len == r*c without computing r*c, which can silently wrap
    -- around Int for very large r and c and then match a small length.
    conformable
      | c == 0    = len == 0
      | otherwise = len `quotRem` c == (r, 0)
-- | The matrix shape as a @(rows, columns)@ pair.
dimension :: Matrix -> (Int, Int)
dimension m = case m of
  Matrix nrows ncols _ _ -> (nrows, ncols)
-- | Matrix–vector product: element @i@ of the result is the dot product of
-- row @i@ of the matrix with the vector.
--
-- Calls 'error' when the vector's length differs from the matrix's column
-- count.
multiplyV :: Matrix -> Vector -> Vector
multiplyV m v
  | cols m /= n = error $ "matrix/vector unconformable " ++ show (cols m, n)
  | otherwise   = U.generate (rows m) rowDot
  where
    n = U.length v
    -- Row i dotted with v; summation uses Statistics.Sample.Internal.sum.
    rowDot i = sum (U.zipWith (*) v (row m i))
-- | Euclidean norm: the square root of the sum of the squared elements
-- (summed with "Statistics.Sample.Internal"'s 'sum').
norm :: Vector -> Double
norm xs = sqrt sumOfSquares
  where sumOfSquares = sum (U.map square xs)
-- | Extract column @j@ as a new vector, one element per row.
--
-- Gathers the entries at flat offsets @j, j + c, j + 2c, …@ (one per row)
-- from the row-major storage, where @c@ is the column count.
column :: Matrix -> Int -> Vector
column (Matrix nrows ncols _ xs) j = U.backpermute xs offsets
  where offsets = U.enumFromStepN j ncols nrows
{-# INLINE column #-}
-- | Row @k@ as a slice of the flat row-major storage: the @c@ elements
-- starting at offset @c*k@, where @c@ is the column count.
row :: Matrix -> Int -> Vector
row (Matrix _ ncols _ xs) k = U.slice start ncols xs
  where start = ncols * k
-- | Read the element at the given row and column.
--
-- No bounds check is performed: the row-major offset is passed straight to
-- 'U.unsafeIndex'.
unsafeIndex :: Matrix
            -> Int      -- ^ Row.
            -> Int      -- ^ Column.
            -> Double
unsafeIndex m i j = unsafeBounds U.unsafeIndex m i j
-- | Convert a (row, column) pair to the row-major offset @row * cols + col@,
-- then pass the flat storage and that offset to the continuation @k@.
--
-- The offset is forced before @k@ runs. No bounds check is done here.
unsafeBounds :: (Vector -> Int -> r) -> Matrix -> Int -> Int -> r
unsafeBounds k (Matrix _ ncols _ xs) i j =
  let off = i * ncols + j
  in off `seq` k xs off
{-# INLINE unsafeBounds #-}
-- | Transpose: an @r@ × @c@ matrix becomes a @c@ × @r@ matrix, with the
-- third constructor field carried over unchanged.
transpose :: Matrix -> Matrix
transpose m@(Matrix nrows ncols e _) =
  Matrix ncols nrows e (U.generate (nrows * ncols) pick)
  where
    -- Flat index i in the output (which has nrows columns) is output row
    -- outR, column outC, and holds input row outC, column outR.
    pick i = let (outR, outC) = i `quotRem` nrows
             in unsafeIndex m outC outR