module Data.Matrix (
Matrix , prettyMatrix
, nrows , ncols
, forceMatrix
, matrix
, rowVector
, colVector
, zero
, identity
, diagonalList
, diagonal
, permMatrix
, fromList , fromLists
, toList , toLists
, getElem , (!) , unsafeGet , safeGet, safeSet
, getRow , safeGetRow , getCol , safeGetCol
, getDiag
, getMatrixAsVector
, setElem
, unsafeSet
, transpose , setSize , extendTo
, inverse, rref
, mapRow , mapCol, mapPos
, submatrix
, minorMatrix
, splitBlocks
, (<|>) , (<->)
, joinBlocks
, elementwise, elementwiseUnsafe
, multStd
, multStd2
, multStrassen
, multStrassenMixed
, scaleMatrix
, scaleRow
, combineRows
, switchRows
, switchCols
, luDecomp , luDecompUnsafe
, luDecomp', luDecompUnsafe'
, cholDecomp
, trace , diagProd
, detLaplace
, detLU
, flatten
) where
import Prelude hiding (foldl1)
import Control.DeepSeq
import Control.Monad (forM_)
import Control.Loop (numLoop,numLoopFold)
import Data.Foldable (Foldable, foldMap, foldl1)
import Data.Maybe
import Data.Monoid
import qualified Data.Semigroup as S
import Data.Traversable
import Control.Applicative(Applicative, (<$>), (<*>), pure)
import GHC.Generics (Generic)
import Control.Monad.Primitive (PrimMonad, PrimState)
import Data.List (maximumBy,foldl1')
import Data.Ord (comparing)
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV
-- | Map a 1-based @(row, column)@ position to a 0-based index into a
-- row-major vector whose rows are @m@ elements wide.
encode :: Int -> (Int,Int) -> Int
encode m (i,j) = (i - 1) * m + j - 1

-- | Inverse of 'encode': turn a 0-based row-major index into a 1-based
-- @(row, column)@ position, for rows that are @m@ elements wide.
decode :: Int -> Int -> (Int,Int)
decode m k = (q+1,r+1)
  where
    (q,r) = quotRem k m
-- | A dense matrix, stored as a window onto a row-major boxed vector.
--
-- The visible matrix is 'nrows' x 'ncols'. Entry @(i, j)@ (1-based) lives at
-- index @encode vcols (i + rowOffset, j + colOffset)@ of 'mvect', which is
-- what lets 'submatrix' share its parent's backing vector instead of copying.
data Matrix a = M {
    nrows :: !Int -- ^ Number of rows of the visible matrix.
  , ncols :: !Int -- ^ Number of columns of the visible matrix.
  , rowOffset :: !Int -- ^ Rows of the backing storage skipped above this view.
  , colOffset :: !Int -- ^ Columns of the backing storage skipped left of this view.
  , vcols :: !Int -- ^ Row width of the backing vector 'mvect'.
  , mvect :: V.Vector a -- ^ Backing row-major vector; may hold entries outside the view.
  } deriving (Generic)
-- | Two matrices are equal when their dimensions match and every visible
-- entry agrees. Offsets and the backing vector are not compared.
instance Eq a => Eq (Matrix a) where
  a == b
    | nrows a /= nrows b = False
    | ncols a /= ncols b = False
    | otherwise = all sameAt [ (i, j) | i <- [1 .. nrows a], j <- [1 .. ncols a] ]
    where
      sameAt p = a ! p == b ! p
-- | Render matrix dimensions as @rows x cols@, e.g. @"2x3"@.
sizeStr :: Int -> Int -> String
sizeStr rows cols = concat [show rows, "x", show cols]
-- | Render a matrix as a multi-line string framed with box-drawing
-- characters. Every entry is right-aligned to the width of the widest
-- visible entry.
prettyMatrix :: Show a => Matrix a -> String
prettyMatrix m = concat
  [ "┌ ", unwords (replicate (ncols m) blank), " ┐\n"
  , unlines
      [ "│ " ++ unwords (fmap (\j -> fill $ strings ! (i,j)) [1..ncols m]) ++ " │" | i <- [1..nrows m] ]
  , "└ ", unwords (replicate (ncols m) blank), " ┘"
  ]
  where
    strings = fmap show m
    -- Measure only the visible entries. A submatrix view shares a larger
    -- backing vector, and an empty matrix has no entries at all; the 0 seed
    -- keeps 'maximum' total in that case.
    widest = maximum (0 : map length (toList strings))
    fill str = replicate (widest - length str) ' ' ++ str
    blank = fill ""
-- | 'show' renders the matrix with 'prettyMatrix'.
instance Show a => Show (Matrix a) where
  show m = prettyMatrix m
-- | Fully evaluates the backing vector, including any entries that lie
-- outside a submatrix view.
instance NFData a => NFData (Matrix a) where
  rnf m = rnf (mvect m)
-- | Copy the visible entries into a fresh, compact backing vector (zero
-- offsets, row width equal to 'ncols'). Unlike a 'submatrix' view, the
-- result does not share its parent's vector.
forceMatrix :: Matrix a -> Matrix a
forceMatrix m = matrix (nrows m) (ncols m) (m !.)
-- | Maps over the whole backing vector; dimensions, offsets and row width
-- of the view are kept as they are.
instance Functor Matrix where
  fmap f m = m { mvect = V.map f (mvect m) }
-- | Combine two matrices entrywise with the element type's '<>'. The result
-- is as large as the larger operand in each dimension: where only one
-- operand has an entry that entry is kept, and where neither has one
-- 'mempty' is used.
--
-- The implementation lives in '<>' and 'mappend' is defined from it, which
-- is the canonical arrangement expected by modern @base@.
instance Monoid a => S.Semigroup (Matrix a) where
  m <> m' = matrix (max (nrows m) (nrows m')) (max (ncols m) (ncols m')) $ uncurry zipTogether
    where
      zipTogether row column = fromMaybe mempty $ safeGet row column m <> safeGet row column m'

-- | 'mempty' is the 1x1 matrix holding 'mempty'; 'mappend' is '<>'.
instance Monoid a => Monoid (Matrix a) where
  mempty = fromList 1 1 [mempty]
  mappend = (S.<>)
-- | 'pure' builds a 1x1 matrix. @fs <*> xs@ applies every function in @fs@
-- to the whole of @xs@ and tiles the resulting matrices with 'flatten'.
instance Applicative Matrix where
  pure x = fromList 1 1 [x]
  fs <*> xs = flatten (fmap (<$> xs) fs)
-- | Tile a matrix of matrices into a single matrix. The blocks of each row
-- are joined with '<|>', and the resulting strips are stacked with '<->'.
-- Fails (via 'foldl1') when the outer matrix has no rows or no columns.
flatten:: Matrix (Matrix a) -> Matrix a
flatten m = foldl1 (<->) [ foldl1 (<|>) (getRow i m) | i <- [1 .. nrows m] ]
-- | Apply @f@ to every entry of row @r@. @f@ receives the column index and
-- the old value; entries in other rows are copied unchanged.
mapRow :: (Int -> a -> a) -- ^ Function that takes the column index as well.
       -> Int -- ^ Row to map.
       -> Matrix a -> Matrix a
mapRow f r m = matrix (nrows m) (ncols m) cell
  where
    cell (i, j)
      | i == r = f j x
      | otherwise = x
      where
        x = unsafeGet i j m
-- | Apply @f@ to every entry of column @c@. @f@ receives the row index and
-- the old value; entries in other columns are copied unchanged.
mapCol :: (Int -> a -> a) -- ^ Function that takes the row index as well.
       -> Int -- ^ Column to map.
       -> Matrix a -> Matrix a
mapCol f c m = matrix (nrows m) (ncols m) cell
  where
    cell (i, j)
      | j == c = f i x
      | otherwise = x
      where
        x = unsafeGet i j m
-- | Map a function over every visible entry, also passing the entry's
-- 1-based @(row, column)@ position. The result is a fresh, compact matrix.
mapPos :: ((Int, Int) -> a -> b) -- ^ Function taking the position as well.
       -> Matrix a
       -> Matrix b
-- Iterate over the view rather than the backing vector. A 'submatrix' view
-- shares a vector whose rows are 'vcols' wide and may start at an offset,
-- so decoding raw vector indices with 'ncols' would report wrong positions.
mapPos f m = matrix (nrows m) (ncols m) $ \p@(i, j) -> f p (unsafeGet i j m)
-- | Folds over the visible entries in row-major order.
instance Foldable Matrix where
  foldMap f m = foldMap f (getMatrixAsVector m)
-- | Sequences the visible entries in row-major order; the result is a
-- compact matrix with the same dimensions.
instance Traversable Matrix where
  sequenceA m = rebuild <$> sequenceA (getMatrixAsVector m)
    where
      cols = ncols m
      rebuild = M (nrows m) cols 0 0 cols
-- | @zero n m@ is the @n@ x @m@ matrix whose entries are all @0@.
zero :: Num a =>
     Int -- ^ Rows
  -> Int -- ^ Columns
  -> Matrix a
zero rows cols = M rows cols 0 0 cols (V.replicate (rows * cols) 0)
-- | @matrix n m f@ builds an @n@ x @m@ matrix whose entry at 1-based
-- position @(i, j)@ is @f (i, j)@. The result owns a fresh, compact backing
-- vector (no offsets, row width @m@).
matrix :: Int -- ^ Rows
       -> Int -- ^ Columns
       -> ((Int,Int) -> a) -- ^ Generator function
       -> Matrix a
matrix n m f = M n m 0 0 m $ V.create $ do
  buf <- MV.new (n * m)
  numLoop 1 n $ \i ->
    numLoop 1 m $ \j ->
      let pos = (i, j)
      in MV.unsafeWrite buf (encode m pos) (f pos)
  return buf
-- | The @n@ x @n@ identity matrix.
identity :: Num a => Int -> Matrix a
identity n = matrix n n delta
  where
    delta (i, j)
      | i == j = 1
      | otherwise = 0
-- | @diagonal e v@ is the square matrix with the elements of @v@ on the main
-- diagonal and @e@ everywhere else; its size is the length of @v@.
diagonal :: a -- ^ Default element (off the diagonal).
         -> V.Vector a -- ^ Diagonal vector.
         -> Matrix a
diagonal e v = matrix n n $ \(i,j) -> if i == j then V.unsafeIndex v (i - 1) else e
  where
    n = V.length v
-- | @fromList n m xs@ builds an @n@ x @m@ matrix from the first @n*m@
-- elements of @xs@, read in row-major order. Extra elements are ignored.
--
-- Calls 'error' if @xs@ has fewer than @n*m@ elements. Without this check
-- the short backing vector would later be read out of bounds by
-- 'unsafeGet'.
fromList :: Int -- ^ Rows
         -> Int -- ^ Columns
         -> [a] -- ^ List of elements
         -> Matrix a
fromList n m xs
  | V.length v < n * m =
      error $ "fromList: a " ++ sizeStr n m ++ " matrix needs "
           ++ show (n * m) ++ " elements, but the list has only "
           ++ show (V.length v) ++ "."
  | otherwise = M n m 0 0 m v
  where
    v = V.fromListN (n * m) xs
-- | The visible entries in row-major order.
toList :: Matrix a -> [a]
toList m = concat (toLists m)

-- | The visible entries as a list of rows.
toLists :: Matrix a -> [[a]]
toLists m = map row [1 .. nrows m]
  where
    row i = map (\j -> unsafeGet i j m) [1 .. ncols m]
-- | @diagonalList n e xs@ is the @n@ x @n@ matrix with the first @n@
-- elements of @xs@ on the main diagonal and @e@ elsewhere. A diagonal entry
-- whose element is missing from @xs@ fails (via '!!') when it is forced.
diagonalList :: Int -> a -> [a] -> Matrix a
diagonalList n e xs = matrix n n $ \(i,j) -> if i == j then xs !! (i - 1) else e
-- | Build a matrix from a non-empty list of rows. The first row fixes the
-- number of columns, and longer rows are truncated to that length.
--
-- Calls 'error' on an empty list, or when any row is shorter than the first
-- one. A short row would otherwise shift every later entry into the wrong
-- position and leave the backing vector under-filled.
fromLists :: [[a]] -> Matrix a
fromLists [] = error "fromLists: empty list."
fromLists (xs:xss)
  | any tooShort xss = error "fromLists: every row must have at least as many elements as the first row."
  | otherwise = fromList n m $ concat $ xs : fmap (take m) xss
  where
    n = 1 + length xss
    m = length xs
    -- Only the first m elements of a row are inspected.
    tooShort row = length (take m row) < m
-- | A 1 x n matrix that uses the given vector directly as its storage.
rowVector :: V.Vector a -> Matrix a
rowVector v = let len = V.length v in M 1 len 0 0 len v

-- | An n x 1 matrix that uses the given vector directly as its storage.
colVector :: V.Vector a -> Matrix a
colVector v = M rows 1 0 0 1 v
  where
    rows = V.length v
-- | @permMatrix n r1 r2@ is the @n@ x @n@ identity matrix with rows @r1@
-- and @r2@ swapped. When @r1 == r2@ it is the identity.
permMatrix :: Num a
           => Int -- ^ Size of the matrix.
           -> Int -- ^ Permuted row 1.
           -> Int -- ^ Permuted row 2.
           -> Matrix a
permMatrix n r1 r2
  | r1 == r2 = identity n
  | otherwise = matrix n n entry
  where
    entry (i, j)
      | i == r1 = indicator (j == r2)
      | i == r2 = indicator (j == r1)
      | otherwise = indicator (i == j)
    indicator b = if b then 1 else 0
-- | Bounds-checked lookup of the entry at 1-based position @(i, j)@.
-- Calls 'error' when the position is outside the matrix.
getElem :: Int -- ^ Row
        -> Int -- ^ Column
        -> Matrix a
        -> a
getElem i j m = maybe outOfBounds id (safeGet i j m)
  where
    outOfBounds = error $
      "getElem: Trying to get the "
      ++ show (i, j)
      ++ " element from a "
      ++ sizeStr (nrows m) (ncols m)
      ++ " matrix."
-- | Unchecked lookup of the entry at 1-based position @(i, j)@, translated
-- through the view's offsets into the backing vector. Out-of-range indices
-- are not detected.
unsafeGet :: Int -- ^ Row
          -> Int -- ^ Column
          -> Matrix a
          -> a
unsafeGet i j m = V.unsafeIndex (mvect m) (encode (vcols m) (i + rowOffset m, j + colOffset m))

-- | Infix form of 'getElem' (bounds-checked).
(!) :: Matrix a -> (Int,Int) -> a
(!) m (i,j) = getElem i j m

-- | Infix form of 'unsafeGet' (unchecked).
(!.) :: Matrix a -> (Int,Int) -> a
(!.) m (i,j) = unsafeGet i j m
-- | Lookup of the entry at 1-based position @(i, j)@; 'Nothing' when the
-- position is outside the matrix.
safeGet :: Int -> Int -> Matrix a -> Maybe a
safeGet i j a@(M n m _ _ _ _)
  | i <= n && j <= m && i >= 1 && j >= 1 = Just (unsafeGet i j a)
  | otherwise = Nothing

-- | Replace the entry at 1-based position @p@; 'Nothing' when the position
-- is outside the matrix.
safeSet:: a -> (Int, Int) -> Matrix a -> Maybe (Matrix a)
safeSet x p@(i,j) a@(M n m _ _ _ _)
  | i <= n && j <= m && i >= 1 && j >= 1 = Just (unsafeSet x p a)
  | otherwise = Nothing
-- | The @i@-th row (1-based), as a slice that shares the backing vector.
-- Only the backing vector's bounds are checked (by 'V.slice'), not the
-- view's; use 'safeGetRow' for a checked lookup.
getRow :: Int -> Matrix a -> V.Vector a
getRow i (M _ m ro co w v) = V.slice (w*(i-1+ro) + co) m v

-- | Like 'getRow', but 'Nothing' when the row index is out of range.
safeGetRow :: Int -> Matrix a -> Maybe (V.Vector a)
safeGetRow r m
  | r > nrows m || r < 1 = Nothing
  | otherwise = Just $ getRow r m

-- | The @j@-th column (1-based), copied into a new vector.
getCol :: Int -> Matrix a -> V.Vector a
getCol j (M n _ ro co w v) = V.generate n $ \i -> v V.! encode w (i+1+ro,j+co)

-- | Like 'getCol', but 'Nothing' when the column index is out of range.
safeGetCol :: Int -> Matrix a -> Maybe (V.Vector a)
safeGetCol c m
  | c > ncols m || c < 1 = Nothing
  | otherwise = Just $ getCol c m
-- | The main diagonal: entries @(k, k)@ for @k@ from 1 to
-- @min (nrows m) (ncols m)@.
getDiag :: Matrix a -> V.Vector a
getDiag m = V.generate (min (nrows m) (ncols m)) (\k -> let d = k + 1 in m ! (d, d))

-- | The visible entries in row-major order, copied into a compact vector.
getMatrixAsVector :: Matrix a -> V.Vector a
getMatrixAsVector m = mvect (forceMatrix m)
-- | Write an element into a mutable backing vector at the view position
-- @(i, j)@, shifted by the given offsets. Uses the bounds-checked 'MV.write'.
msetElem :: PrimMonad m
         => a -- ^ Element to set
         -> Int -- ^ Row width of the backing vector
         -> Int -- ^ Row offset
         -> Int -- ^ Column offset
         -> (Int,Int) -- ^ Position to set the new element
         -> MV.MVector (PrimState m) a -- ^ Mutable vector
         -> m ()
msetElem x w ro co (i,j) vec = MV.write vec idx x
  where
    idx = encode w (i + ro, j + co)

-- | Like 'msetElem', but uses the unchecked 'MV.unsafeWrite'.
unsafeMset :: PrimMonad m
           => a -- ^ Element to set
           -> Int -- ^ Row width of the backing vector
           -> Int -- ^ Row offset
           -> Int -- ^ Column offset
           -> (Int,Int) -- ^ Position to set the new element
           -> MV.MVector (PrimState m) a -- ^ Mutable vector
           -> m ()
unsafeMset x w ro co (i,j) vec = MV.unsafeWrite vec idx x
  where
    idx = encode w (i + ro, j + co)
-- | Replace the entry at 1-based position @p@. The backing vector is copied
-- (via 'V.modify'); only the backing vector's bounds are checked.
setElem :: a -- ^ New value.
        -> (Int,Int) -- ^ Position to replace.
        -> Matrix a -- ^ Original matrix.
        -> Matrix a -- ^ Matrix with the given position replaced.
setElem x p m =
  m { mvect = V.modify (msetElem x (vcols m) (rowOffset m) (colOffset m) p) (mvect m) }

-- | Like 'setElem', but performs no bounds check at all.
unsafeSet :: a -- ^ New value.
          -> (Int,Int) -- ^ Position to replace.
          -> Matrix a -- ^ Original matrix.
          -> Matrix a -- ^ Matrix with the given position replaced.
unsafeSet x p m =
  m { mvect = V.modify (unsafeMset x (vcols m) (rowOffset m) (colOffset m) p) (mvect m) }
-- | Swap rows and columns, producing a fresh compact matrix.
transpose :: Matrix a -> Matrix a
transpose m = matrix (ncols m) (nrows m) (\(r, c) -> getElem c r m)
-- | Inverse of a square matrix, computed by reducing @[m | I]@ to reduced
-- row-echelon form with 'rref' and taking the right half.
--
-- Returns @Left@ for a non-square matrix (dimensions reported as
-- rows x columns), or with 'rref''s message when elimination fails, e.g.
-- because a pivot column contains only zeros.
inverse :: (Fractional a, Eq a) => Matrix a -> Either String (Matrix a)
inverse m
  | ncols m /= nrows m
  = Left
      $ "Inverting non-square matrix with dimensions "
      ++ sizeStr (nrows m) (ncols m)
  | otherwise =
      let
        n = nrows m
        adjoinedWId = m <|> identity n
      in submatrix 1 n (n + 1) (2 * n) <$> rref adjoinedWId
-- | Reduced row-echelon form. The matrix must have at least as many columns
-- as rows.
--
-- 'ref' first produces a row-echelon form with a leading 1 in each diagonal
-- position. Then, starting from the last row, the entries above each pivot
-- are eliminated, working upwards recursively.
rref :: (Fractional a, Eq a) => Matrix a -> Either String (Matrix a)
rref m
  | ncols m < nrows m
  = Left $
      "Invalid dimensions "
      ++ sizeStr (nrows m) (ncols m)
      ++ "; the number of columns must be greater than or equal to the number of rows"
  | otherwise = rrefRefd =<< ref m
  where
    rrefRefd mtx
      | nrows mtx == 1 = Right mtx
      | otherwise =
          let
            col = nrows mtx
            -- Subtract (entry in column col) x (row col) from row n. Row col
            -- has a leading 1 in column col, so the entry becomes 0.
            resolveRow n = combineRows n (negate (getElem n col mtx)) col
            resolvedRight = foldr (.) id (map resolveRow [1..col-1]) mtx
            top = submatrix 1 (nrows resolvedRight - 1) 1 (ncols resolvedRight) resolvedRight
            bot = submatrix (nrows resolvedRight) (nrows resolvedRight) 1 (ncols resolvedRight) resolvedRight
          in (<-> bot) <$> rrefRefd top
-- | Row-echelon form by Gaussian elimination, used by 'rref'.
--
-- The first row with a non-zero entry in column 1 is swapped to the top and
-- scaled so that this entry becomes 1. Multiples of that row are then
-- subtracted from every row below it to clear column 1, and the procedure
-- recurses on the lower-right block. Returns @Left@ when column 1 of the
-- (sub)matrix is entirely zero.
ref :: (Fractional a, Eq a) => Matrix a -> Either String (Matrix a)
ref mtx
  | nrows mtx == 1
  = clearedLeft
  | otherwise = do
      (tl, tr, bl, br) <- splitBlocks 1 1 <$> clearedLeft
      br' <- ref br
      return ((tl <|> tr) <-> (bl <|> br'))
  where
    sigAtTop = (\row -> switchRows 1 row mtx) <$> goodRow
      where
        significantRow n = getElem n 1 mtx /= 0
        goodRow = case listToMaybe (filter significantRow [1..nrows mtx]) of
          Nothing -> Left "Attempt to invert a non-invertible matrix"
          Just x -> return x
    normalizedFirstRow = (\sigAtTop' -> scaleRow (1 / getElem 1 1 sigAtTop') 1 sigAtTop') <$> sigAtTop
    clearedLeft = do
        comb <- mapM combinator [2..nrows mtx]
        firstRow <- normalizedFirstRow
        return $ (foldr (.) id comb) firstRow
      where
        -- Subtract (entry in column 1) x (pivot row) from row n. The pivot
        -- entry is 1 after normalisation, so column 1 of row n becomes 0.
        combinator n = (\normalizedFirstRow' -> combineRows n (negate (getElem n 1 normalizedFirstRow')) 1) <$> normalizedFirstRow
-- | Grow a matrix to at least @n@ rows and @m@ columns, filling new
-- positions with @e@. It never shrinks.
extendTo :: a -- ^ Element to add when extending.
         -> Int -- ^ Minimal number of rows.
         -> Int -- ^ Minimal number of columns.
         -> Matrix a -> Matrix a
extendTo e n m a = setSize e rows cols a
  where
    rows = max n (nrows a)
    cols = max m (ncols a)
-- | Resize to exactly @n@ x @m@. Existing entries are kept where they fit,
-- and positions beyond the original size are filled with @e@.
setSize :: a -- ^ Default element.
        -> Int -- ^ Number of rows.
        -> Int -- ^ Number of columns.
        -> Matrix a
        -> Matrix a
setSize e n m a@(M n0 m0 _ _ _ _) = matrix n m pick
  where
    pick (i, j)
      | i > n0 || j > m0 = e
      | otherwise = unsafeGet i j a
-- | /O(1)/. The block spanning rows @r1..r2@ and columns @c1..c2@ (1-based,
-- inclusive). The result shares the backing vector and only adjusts
-- dimensions and offsets. Calls 'error' for an out-of-range or empty span.
submatrix :: Int -- ^ Starting row
          -> Int -- ^ Ending row
          -> Int -- ^ Starting column
          -> Int -- ^ Ending column
          -> Matrix a
          -> Matrix a
submatrix r1 r2 c1 c2 (M n m ro co w v)
  | r1 < 1 || r1 > n = error $ "submatrix: starting row (" ++ show r1 ++ ") is out of range. Matrix has " ++ show n ++ " rows."
  | c1 < 1 || c1 > m = error $ "submatrix: starting column (" ++ show c1 ++ ") is out of range. Matrix has " ++ show m ++ " columns."
  | r2 < r1 || r2 > n = error $ "submatrix: ending row (" ++ show r2 ++ ") is out of range. Matrix has " ++ show n ++ " rows, and starting row is " ++ show r1 ++ "."
  | c2 < c1 || c2 > m = error $ "submatrix: ending column (" ++ show c2 ++ ") is out of range. Matrix has " ++ show m ++ " columns, and starting column is " ++ show c1 ++ "."
  | otherwise = M (r2 - r1 + 1) (c2 - c1 + 1) (ro + r1 - 1) (co + c1 - 1) w v
-- | The matrix with row @r0@ and column @c0@ (1-based, relative to the view)
-- removed. The row and column are filtered out of the whole backing
-- vector, which therefore becomes one row and one column smaller, while the
-- view keeps its offsets.
minorMatrix :: Int -- ^ Row @r@ to remove.
            -> Int -- ^ Column @c@ to remove.
            -> Matrix a -- ^ Original matrix.
            -> Matrix a -- ^ Matrix with row @r@ and column @c@ removed.
minorMatrix r0 c0 (M n m ro co w v) =
  let r = r0 + ro
      c = c0 + co
  in M (n - 1) (m - 1) ro co (w - 1) $ V.ifilter (\k _ -> let (i,j) = decode w k in i /= r && j /= c) v
-- | Split into four blocks around the cut after row @i@ and column @j@:
-- (top-left, top-right, bottom-left, bottom-right). Each block is a
-- 'submatrix' view; a cut at the last row or column makes 'submatrix' fail.
splitBlocks :: Int -- ^ Row of the splitting element.
            -> Int -- ^ Column of the splitting element.
            -> Matrix a -- ^ Matrix to split.
            -> (Matrix a,Matrix a
               ,Matrix a,Matrix a) -- ^ (TL,TR,BL,BR)
splitBlocks i j a@(M n m _ _ _ _) = (topLeft, topRight, bottomLeft, bottomRight)
  where
    topLeft = submatrix 1 i 1 j a
    topRight = submatrix 1 i (j + 1) m a
    bottomLeft = submatrix (i + 1) n 1 j a
    bottomRight = submatrix (i + 1) n (j + 1) m a
-- | Inverse of 'splitBlocks': assemble (top-left, top-right, bottom-left,
-- bottom-right) into one compact matrix. Sizes are taken from the
-- top-left, top-right and bottom-left blocks; entries are read with the
-- bounds-checked '!', so mismatched blocks fail when an entry is forced.
joinBlocks :: (Matrix a,Matrix a,Matrix a,Matrix a) -> Matrix a
joinBlocks (tl,tr,bl,br) =
  M rows cols 0 0 cols $ V.create $ do
    out <- MV.new (rows * cols)
    let put i j = MV.write out (encode cols (i, j))
    numLoop 1 topRows $ \i -> do
      numLoop 1 leftCols $ \j -> put i j (tl ! (i,j))
      numLoop 1 rightCols $ \j -> put i (j + leftCols) (tr ! (i,j))
    numLoop 1 bottomRows $ \i -> do
      let i' = i + topRows
      numLoop 1 leftCols $ \j -> put i' j (bl ! (i,j))
      numLoop 1 rightCols $ \j -> put i' (j + leftCols) (br ! (i,j))
    return out
  where
    topRows = nrows tl
    bottomRows = nrows bl
    leftCols = ncols tl
    rightCols = ncols tr
    rows = topRows + bottomRows
    cols = leftCols + rightCols
-- | Horizontal concatenation: the columns of @m'@ are appended to the right
-- of @m@. The result has @nrows m@ rows, so @m'@ needs at least that many.
(<|>) :: Matrix a -> Matrix a -> Matrix a
m <|> m' =
  let c = ncols m
  in matrix (nrows m) (c + ncols m') $ \(i,j) ->
       if j <= c then m ! (i,j) else m' ! (i, j - c)

-- | Vertical concatenation: the rows of @m'@ are appended below @m@. The
-- result has @ncols m@ columns, so @m'@ needs at least that many.
(<->) :: Matrix a -> Matrix a -> Matrix a
m <-> m' =
  let r = nrows m
  in matrix (r + nrows m') (ncols m) $ \(i,j) ->
       if i <= r then m ! (i,j) else m' ! (i - r, j)
-- | Combine two matrices entry by entry. The result has the size of the
-- first matrix. Entries of the second are read with the bounds-checked '!',
-- so a smaller second matrix fails when an out-of-range entry is forced,
-- and a larger one has its extra entries ignored.
elementwise :: (a -> b -> c) -> (Matrix a -> Matrix b -> Matrix c)
elementwise f m m' = matrix (nrows m) (ncols m) (\p -> f (m ! p) (m' ! p))

-- | Like 'elementwise', but reads both matrices with the unchecked
-- 'unsafeGet'; the second matrix must be at least as large as the first.
elementwiseUnsafe :: (a -> b -> c) -> (Matrix a -> Matrix b -> Matrix c)
elementwiseUnsafe f m m' = matrix (nrows m) (ncols m) combine
  where
    combine (i, j) = f (unsafeGet i j m) (unsafeGet i j m')
infixl 6 +., -.

-- | Unchecked entrywise addition (see 'elementwiseUnsafe').
(+.) :: Num a => Matrix a -> Matrix a -> Matrix a
(+.) = elementwiseUnsafe (+)

-- | Unchecked entrywise subtraction (see 'elementwiseUnsafe').
(-.) :: Num a => Matrix a -> Matrix a -> Matrix a
(-.) = elementwiseUnsafe (-)
-- | Standard matrix multiplication by definition, with hand-unrolled cases
-- for 1x1, 2x2 and 3x3 operands. Calls 'error' when the column count of
-- the first matrix differs from the row count of the second.
multStd :: Num a => Matrix a -> Matrix a -> Matrix a
multStd = checkedMult multStd_

-- | Standard matrix multiplication computed as row/column dot products.
-- Calls 'error' on mismatched inner dimensions.
multStd2 :: Num a => Matrix a -> Matrix a -> Matrix a
multStd2 = checkedMult multStd__

-- | Run an unchecked multiplication kernel after verifying that the inner
-- dimensions agree.
checkedMult :: (Matrix a -> Matrix a -> Matrix a) -> Matrix a -> Matrix a -> Matrix a
checkedMult kernel a1@(M n m _ _ _ _) a2@(M n' m' _ _ _ _)
  | m /= n' = error $ "Multiplication of " ++ sizeStr n m ++ " and "
                      ++ sizeStr n' m' ++ " matrices."
  | otherwise = kernel a1 a2
-- | Standard multiplication without a dimension check (see 'multStd' for
-- the checked entry point). 1x1, 2x2 and 3x3 operands use hand-unrolled
-- formulas. Every other shape computes each entry as a sum over the inner
-- index with 'matrix'. The result is always a fresh, compact matrix.
multStd_ :: Num a => Matrix a -> Matrix a -> Matrix a
multStd_ a@(M 1 1 _ _ _ _) b@(M 1 1 _ _ _ _) = M 1 1 0 0 1 $ V.singleton $ (a ! (1,1)) * (b ! (1,1))
-- 2x2: explicit row-by-column products, stored row-major.
multStd_ a@(M 2 2 _ _ _ _) b@(M 2 2 _ _ _ _) =
  M 2 2 0 0 2 $
    let
      a11 = a !. (1,1) ; a12 = a !. (1,2)
      a21 = a !. (2,1) ; a22 = a !. (2,2)
      b11 = b !. (1,1) ; b12 = b !. (1,2)
      b21 = b !. (2,1) ; b22 = b !. (2,2)
    in V.fromList
         [ a11*b11 + a12*b21 , a11*b12 + a12*b22
         , a21*b11 + a22*b21 , a21*b12 + a22*b22
         ]
-- 3x3: explicit row-by-column products, stored row-major.
multStd_ a@(M 3 3 _ _ _ _) b@(M 3 3 _ _ _ _) =
  M 3 3 0 0 3 $
    let
      a11 = a !. (1,1) ; a12 = a !. (1,2) ; a13 = a !. (1,3)
      a21 = a !. (2,1) ; a22 = a !. (2,2) ; a23 = a !. (2,3)
      a31 = a !. (3,1) ; a32 = a !. (3,2) ; a33 = a !. (3,3)
      b11 = b !. (1,1) ; b12 = b !. (1,2) ; b13 = b !. (1,3)
      b21 = b !. (2,1) ; b22 = b !. (2,2) ; b23 = b !. (2,3)
      b31 = b !. (3,1) ; b32 = b !. (3,2) ; b33 = b !. (3,3)
    in V.fromList
         [ a11*b11 + a12*b21 + a13*b31 , a11*b12 + a12*b22 + a13*b32 , a11*b13 + a12*b23 + a13*b33
         , a21*b11 + a22*b21 + a23*b31 , a21*b12 + a22*b22 + a23*b32 , a21*b13 + a22*b23 + a23*b33
         , a31*b11 + a32*b21 + a33*b31 , a31*b12 + a32*b22 + a33*b32 , a31*b13 + a32*b23 + a33*b33
         ]
-- General case: entry (i, j) is the sum over k of a(i,k) * b(k,j).
multStd_ a@(M n m _ _ _ _) b@(M _ m' _ _ _ _) = matrix n m' $ \(i,j) -> sum [ a !. (i,k) * b !. (k,j) | k <- [1 .. m] ]
-- | Multiplication without a dimension check. The rows of @a@ and the
-- columns of @b@ are extracted once into vectors, and each result entry is
-- a 'dotProduct' of one row and one column.
multStd__ :: Num a => Matrix a -> Matrix a -> Matrix a
multStd__ a b = matrix r c $ \(i,j) -> dotProduct (V.unsafeIndex avs (i - 1)) (V.unsafeIndex bvs (j - 1))
  where
    r = nrows a
    avs = V.generate r $ \i -> getRow (i+1) a
    c = ncols b
    bvs = V.generate c $ \i -> getCol (i+1) b
-- | Sum of pairwise products over the indices of @v1@. Indexing is
-- unchecked, so @v2@ must be at least as long as @v1@.
dotProduct :: Num a => V.Vector a -> V.Vector a -> a
dotProduct v1 v2 = numLoopFold 0 (V.length v1 - 1) 0 $
  \r i -> V.unsafeIndex v1 i * V.unsafeIndex v2 i + r
-- | The first list element satisfying the predicate. Calls 'error' when no
-- element matches. The list is consumed only up to the first match, so
-- infinite lists are fine when a match exists.
first :: (a -> Bool) -> [a] -> a
first f xs = case filter f xs of
  (x:_) -> x
  [] -> error "first: no element match the condition."
-- | Strassen's algorithm for square matrices whose size is a power of two
-- (see 'multStrassen', which pads to such a size). Each level splits both
-- operands into quadrants and forms seven recursive products P1..P7. These
-- are combined with the standard Strassen formulas.
strassen :: Num a => Matrix a -> Matrix a -> Matrix a
strassen a@(M 1 1 _ _ _ _) b@(M 1 1 _ _ _ _) = M 1 1 0 0 1 $ V.singleton $ (a ! (1,1)) * (b ! (1,1))
strassen a b = joinBlocks (c11,c12,c21,c22)
  where
    n = div (nrows a) 2
    (a11,a12,a21,a22) = splitBlocks n n a
    (b11,b12,b21,b22) = splitBlocks n n b
    p1 = strassen (a11 + a22) (b11 + b22)
    p2 = strassen (a21 + a22) b11
    p3 = strassen a11 (b12 - b22)
    p4 = strassen a22 (b21 - b11)
    p5 = strassen (a11 + a12) b22
    p6 = strassen (a21 - a11) (b11 + b12)
    p7 = strassen (a12 - a22) (b21 + b22)
    c11 = p1 + p4 - p5 + p7
    c12 = p3 + p5
    c21 = p2 + p4
    c22 = p1 - p2 + p3 + p6
-- | Multiplication with Strassen's algorithm. Both operands are zero-padded
-- to a square whose side is the smallest power of two that is at least
-- their largest dimension, and the padding is cut from the product. Calls
-- 'error' on mismatched inner dimensions.
multStrassen :: Num a => Matrix a -> Matrix a -> Matrix a
multStrassen a1@(M n m _ _ _ _) a2@(M n' m' _ _ _ _)
  | m /= n' = error $ "Multiplication of " ++ sizeStr n m ++ " and "
                      ++ sizeStr n' m' ++ " matrices."
  | otherwise = submatrix 1 n 1 m' (strassen padded1 padded2)
  where
    side = first (>= maximum [n, m, n', m']) (map (2 ^) [(0 :: Int) ..])
    padded1 = setSize 0 side side a1
    padded2 = setSize 0 side side a2
-- | Size threshold used by 'strassenMixed' and 'multStrassenMixed'. When the
-- row count is below this value, they multiply with 'multStd__' instead of
-- splitting further with Strassen's algorithm.
strmixFactor :: Int
strmixFactor = 300
-- | Strassen multiplication for two square matrices of the same size.
-- Below 'strmixFactor' rows it falls back to 'multStd__'. An odd size is
-- padded with one zero row and column before recursing, and the padding is
-- cut from the result. Otherwise the seven products P1..P7 are computed
-- recursively, and the four quadrants of the result are written directly
-- into one output vector.
strassenMixed :: Num a => Matrix a -> Matrix a -> Matrix a
strassenMixed a b
  | r < strmixFactor = multStd__ a b
  | odd r =
      let r' = r + 1
          a' = setSize 0 r' r' a
          b' = setSize 0 r' r' b
      in submatrix 1 r 1 r $ strassenMixed a' b'
  | otherwise =
      M r r 0 0 r $ V.create $ do
        v <- MV.unsafeNew (r*r)
        let en = encode r
            n' = n + 1
        -- Top-left quadrant: C11 = P1 + P4 - P5 + P7
        sequence_ [ MV.write v k $
                        unsafeGet i j p1
                      + unsafeGet i j p4
                      - unsafeGet i j p5
                      + unsafeGet i j p7
                  | i <- [1..n]
                  , j <- [1..n]
                  , let k = en (i,j)
                  ]
        -- Top-right quadrant: C12 = P3 + P5
        sequence_ [ MV.write v k $
                        unsafeGet i j' p3
                      + unsafeGet i j' p5
                  | i <- [1..n]
                  , j <- [n'..r]
                  , let k = en (i,j)
                  , let j' = j - n
                  ]
        -- Bottom-left quadrant: C21 = P2 + P4
        sequence_ [ MV.write v k $
                        unsafeGet i' j p2
                      + unsafeGet i' j p4
                  | i <- [n'..r]
                  , j <- [1..n]
                  , let k = en (i,j)
                  , let i' = i - n
                  ]
        -- Bottom-right quadrant: C22 = P1 - P2 + P3 + P6
        sequence_ [ MV.write v k $
                        unsafeGet i' j' p1
                      - unsafeGet i' j' p2
                      + unsafeGet i' j' p3
                      + unsafeGet i' j' p6
                  | i <- [n'..r]
                  , j <- [n'..r]
                  , let k = en (i,j)
                  , let i' = i - n
                  , let j' = j - n
                  ]
        return v
  where
    r = nrows a
    n = quot r 2
    (a11,a12,a21,a22) = splitBlocks n n a
    (b11,b12,b21,b22) = splitBlocks n n b
    p1 = strassenMixed (a11 +. a22) (b11 +. b22)
    p2 = strassenMixed (a21 +. a22) b11
    p3 = strassenMixed a11 (b12 -. b22)
    p4 = strassenMixed a22 (b21 -. b11)
    p5 = strassenMixed (a11 +. a12) b22
    p6 = strassenMixed (a21 -. a11) (b11 +. b12)
    p7 = strassenMixed (a12 -. a22) (b21 +. b22)
-- | Multiplication that uses 'multStd__' when the first matrix has fewer
-- than 'strmixFactor' rows. Otherwise it zero-pads both operands to an
-- even square and uses 'strassenMixed'. Calls 'error' on mismatched inner
-- dimensions.
multStrassenMixed :: Num a => Matrix a -> Matrix a -> Matrix a
multStrassenMixed a1@(M n m _ _ _ _) a2@(M n' m' _ _ _ _)
  | m /= n' = error $ "Multiplication of " ++ sizeStr n m ++ " and "
                      ++ sizeStr n' m' ++ " matrices."
  | n < strmixFactor = multStd__ a1 a2
  | otherwise = submatrix 1 n 1 m' (strassenMixed (pad a1) (pad a2))
  where
    largest = maximum [n, m, n', m']
    side
      | even largest = largest
      | otherwise = largest + 1
    pad x = setSize 0 side side x
-- | 'fromInteger' builds a 1x1 matrix. 'negate', 'abs' and 'signum' act on
-- each entry. '+' and '-' are entrywise (via 'elementwise', so the result
-- has the first operand's size). '*' is matrix multiplication via
-- 'multStrassenMixed'.
instance Num a => Num (Matrix a) where
  fromInteger = M 1 1 0 0 1 . V.singleton . fromInteger
  negate = fmap negate
  abs = fmap abs
  signum = fmap signum
  (+) = elementwise (+)
  (-) = elementwise (-)
  (*) = multStrassenMixed
-- | Multiply every entry by a scalar.
scaleMatrix :: Num a => a -> Matrix a -> Matrix a
scaleMatrix k = fmap (k *)

-- | Multiply every entry of one row by a scalar.
scaleRow :: Num a => a -> Int -> Matrix a -> Matrix a
scaleRow k = mapRow (\_ x -> k * x)

-- | @combineRows r1 l r2 m@ adds @l@ times row @r2@ to row @r1@. The
-- entries of row @r2@ are read from @m@ with the bounds-checked 'getElem'.
combineRows :: Num a => Int -> a -> Int -> Matrix a -> Matrix a
combineRows r1 l r2 m = mapRow addScaled r1 m
  where
    addScaled j x = x + l * getElem r2 j m
-- | Swap two rows (1-based). The backing vector is copied via 'V.modify',
-- and only its bounds are checked.
switchRows :: Int -- ^ Row 1.
           -> Int -- ^ Row 2.
           -> Matrix a -- ^ Original matrix.
           -> Matrix a -- ^ Matrix with rows 1 and 2 switched.
switchRows r1 r2 (M n m ro co w vs) = M n m ro co w $ V.modify (\mv ->
  numLoop 1 m $ \j ->
    MV.swap mv (encode w (r1+ro,j+co)) (encode w (r2+ro,j+co))) vs

-- | Swap two columns (1-based). The backing vector is copied via
-- 'V.modify', and only its bounds are checked.
switchCols :: Int -- ^ Col 1.
           -> Int -- ^ Col 2.
           -> Matrix a -- ^ Original matrix.
           -> Matrix a -- ^ Matrix with cols 1 and 2 switched.
switchCols c1 c2 (M n m ro co w vs) = M n m ro co w $ V.modify (\mv ->
  numLoop 1 n $ \i ->
    -- Index with the backing row width w (not the view's column count m),
    -- exactly as switchRows does, so that submatrix views are addressed
    -- correctly.
    MV.swap mv (encode w (i+ro,c1+co)) (encode w (i+ro,c2+co))) vs
-- | LU decomposition with partial pivoting. Returns @(U, L, P, d)@, where
-- @d@ is @1@ flipped in sign once per row swap. Returns 'Nothing' when a
-- zero pivot is met. The factors are presumably intended to satisfy
-- @P * A = L * U@ (NOTE(review): confirm against tests).
luDecomp :: (Ord a, Fractional a) => Matrix a -> Maybe (Matrix a,Matrix a,Matrix a,a)
luDecomp a = recLUDecomp a eye eye 1 1 (min rows (ncols a))
  where
    rows = nrows a
    eye = identity rows
-- | One step (column @k@) of LU decomposition with partial pivoting,
-- recursing until @k > n@. The pivot row is swapped into place in @U@ and
-- @P@, and the already-computed part of @L@ (columns @1..k-1@) is swapped
-- too. The rows below are eliminated, with their multipliers stored in
-- @L@. Returns 'Nothing' on a zero pivot.
recLUDecomp :: (Ord a, Fractional a)
            => Matrix a -- ^ U
            -> Matrix a -- ^ L
            -> Matrix a -- ^ P
            -> a -- ^ d
            -> Int -- ^ Current row
            -> Int -- ^ Total rows
            -> Maybe (Matrix a,Matrix a,Matrix a,a)
recLUDecomp u l p d k n =
    if k > n then Just (u,l,p,d)
    else if ukk == 0 then Nothing
    else recLUDecomp u'' l'' p' d' (k+1) n
  where
    -- Partial pivoting: pick the row at or below k whose entry in column k
    -- has the largest magnitude. Every row is searched, not only the first
    -- n = min rows cols, so tall matrices are pivoted correctly; this
    -- matches the search range used by recLUDecomp'.
    i = maximumBy (comparing (\x -> abs $ u ! (x,k))) [ k .. nrows u ]
    u' = switchRows k i u
    -- Swap only the first k-1 columns of L: the remaining columns still
    -- hold the identity and must stay in place.
    l' = let lw = vcols l
             en = encode lw
             lro = rowOffset l
             lco = colOffset l
         in if i == k
              then l
              else M (nrows l) (ncols l) lro lco lw $
                     V.modify (\mv -> forM_ [1 .. k-1] $
                                 \j -> MV.swap mv (en (i+lro,j+lco))
                                                  (en (k+lro,j+lco))
                              ) $ mvect l
    p' = switchRows k i p
    d' = if i == k then d else negate d
    (u'',l'') = go u' l' (k+1)
    ukk = u' ! (k,k)
    -- Eliminate column k below the pivot: row j -= x * row k, and record
    -- the multiplier x in L at (j, k).
    go u_ l_ j =
      if j > nrows u_
        then (u_,l_)
        else let x = (u_ ! (j,k)) / ukk
             in go (combineRows j (-x) k u_) (setElem x (j,k) l_) (j+1)
-- | Like 'luDecomp', but calls 'error' instead of returning 'Nothing'.
luDecompUnsafe :: (Ord a, Fractional a) => Matrix a -> (Matrix a, Matrix a, Matrix a, a)
luDecompUnsafe m = fromMaybe (error "luDecompUnsafe of singular matrix.") (luDecomp m)
-- | LU decomposition with complete (row and column) pivoting. Returns
-- @(U, L, P, Q, d, e)@: @P@ and @Q@ record the row and column swaps, and
-- @d@ / @e@ are @1@ flipped in sign once per row / column swap.
luDecomp' :: (Ord a, Fractional a) => Matrix a -> Maybe (Matrix a,Matrix a,Matrix a,Matrix a,a,a)
luDecomp' a = recLUDecomp' a eye eye (identity (ncols a)) 1 1 1 (min rows (ncols a))
  where
    rows = nrows a
    eye = identity rows
-- | Like luDecomp', but calls 'error' instead of returning 'Nothing'.
luDecompUnsafe' :: (Ord a, Fractional a) => Matrix a -> (Matrix a, Matrix a, Matrix a, Matrix a, a, a)
luDecompUnsafe' m = fromMaybe (error "luDecompUnsafe' of singular matrix.") (luDecomp' m)
-- | One step (index @k@) of LU decomposition with complete pivoting. The
-- largest-magnitude entry of the remaining lower-right block is moved to
-- @(k, k)@ by a row and a column swap, and the rows below are eliminated.
-- Stops early (returning the current factors) once the pivot is zero.
recLUDecomp' :: (Ord a, Fractional a)
            => Matrix a -- ^ U
            -> Matrix a -- ^ L
            -> Matrix a -- ^ P
            -> Matrix a -- ^ Q
            -> a -- ^ d
            -> a -- ^ e
            -> Int -- ^ Current row
            -> Int -- ^ Total rows
            -> Maybe (Matrix a,Matrix a,Matrix a,Matrix a,a,a)
recLUDecomp' u l p q d e k n =
    if k > n || u'' ! (k, k) == 0
    then Just (u,l,p,q,d,e)
    -- NOTE(review): row k is never modified by 'go', so u'' ! (k,k) equals
    -- ukk and this branch looks unreachable; kept for safety.
    else if ukk == 0
    then Nothing
    else recLUDecomp' u'' l'' p' q' d' e' (k+1) n
  where
    (i, j) = maximumBy (comparing (\(i0, j0) -> abs $ u ! (i0,j0)))
      [ (i0, j0) | i0 <- [k .. nrows u], j0 <- [k .. ncols u] ]
    u' = switchCols k j $ switchRows k i u
    -- Swapping both row and column k/i of L keeps it lower triangular,
    -- because its columns from k on are still those of the identity.
    l'0 = switchRows k i l
    l' = switchCols k i l'0
    p' = switchRows k i p
    q' = switchCols k j q
    d' = if i == k then d else negate d
    e' = if j == k then e else negate e
    (u'',l'') = go u' l' (k+1)
    ukk = u' ! (k,k)
    -- Eliminate column k below the pivot: row h -= x * row k, and record
    -- the multiplier x in L at (h, k).
    go u_ l_ h =
      if h > nrows u_
        then (u_,l_)
        else let x = (u_ ! (h,k)) / ukk
             in go (combineRows h (-x) k u_) (setElem x (h,k) l_) (h+1)
-- | Cholesky decomposition by recursive block factorisation. It returns a
-- lower-triangular @L@, intended to satisfy @A = L * transpose L@ for a
-- symmetric positive-definite @A@ (not checked). Each step takes the
-- square root of the top-left entry, scales the first column, and recurses
-- on the Schur complement @A22 - L21 * L21^T@.
cholDecomp :: (Floating a) => Matrix a -> Matrix a
cholDecomp a
  | (nrows a == 1) && (ncols a == 1) = fmap sqrt a
  | otherwise = joinBlocks (l11, l12, l21, l22)
  where
    (a11, a12, a21, a22) = splitBlocks 1 1 a
    l11' = sqrt (a11 ! (1,1))
    l11 = fromList 1 1 [l11']
    l12 = zero (nrows a12) (ncols a12)
    l21 = scaleMatrix (1/l11') a21
    a22' = a22 - multStd l21 (transpose l21)
    l22 = cholDecomp a22'
-- | Sum of the main-diagonal entries (see 'getDiag').
trace :: Num a => Matrix a -> a
trace m = V.sum (getDiag m)

-- | Product of the main-diagonal entries (see 'getDiag').
diagProd :: Num a => Matrix a -> a
diagProd m = V.product (getDiag m)
-- | Determinant by Laplace (cofactor) expansion along the first column.
-- The recursion takes factorial time and ends only at a 1x1 matrix.
detLaplace :: Num a => Matrix a -> a
detLaplace m@(M 1 1 _ _ _ _) = m ! (1,1)
detLaplace m = sum1 [ (-1)^(i-1) * m ! (i,1) * detLaplace (minorMatrix i 1 m) | i <- [1 .. nrows m] ]
  where
    sum1 = foldl1' (+)
-- | Determinant via 'luDecomp': the sign @d@ times the product of U's
-- diagonal. Returns 0 when 'luDecomp' meets a zero pivot.
detLU :: (Ord a, Fractional a) => Matrix a -> a
detLU m = maybe 0 (\(u, _, _, d) -> d * diagProd u) (luDecomp m)