{-# LANGUAGE Haskell2010
, TypeFamilies
, FlexibleContexts
, Trustworthy
, StandaloneDeriving
, DeriveDataTypeable
, ConstrainedClassMethods
#-}
{-# OPTIONS_GHC -Wall -fno-warn-name-shadowing #-}
module Numeric.Matrix (
Matrix,
MatrixElement (..),
(<|>),
(<->),
scale,
isUnit,
isZero,
isDiagonal,
isEmpty,
isSquare,
toDoubleMatrix,
toComplexMatrix,
toRationalMatrix
) where
import Control.Applicative ((<$>))
import Control.DeepSeq
import Control.Monad
import Control.Monad.ST
import Data.Function (on)
import Data.Ratio
import Data.Complex
import Data.Maybe
import Data.Int
import Data.Word
import qualified Data.List as L
import Data.Array.IArray
import Data.Array.MArray
import Data.Array.Unboxed
import Data.Array.ST
import Data.STRef
import Data.Binary
import qualified Data.Array.Unsafe as U
import Data.Typeable
import Prelude (Show, Read, Num, Fractional, Eq, Bool (..), Integer, Integral,
Float, Double, RealFloat, Ord, Real,
(*), (/), (+), (-), (^), (.), (>=), (==), (/=), ($), (>), (!!),
(&&), (||),
undefined, null, head, zip, abs, flip, length, compare, drop,
negate, not, filter, fromIntegral, fst, snd, foldl1, min, max,
error, fromInteger, signum, lines, words, show, unwords,
unlines,
otherwise, id, const, uncurry, quot, toRational, fromRational)
import qualified Prelude as P
import Data.Monoid
-- | A dense matrix whose rows are stored in an array indexed from 1.  The
-- concrete representation is chosen per element type by the data instances
-- that follow.
data family Matrix e
-- Typeable is derived standalone for the family.  The CPP test selects the
-- Typeable class on GHC 7.7 and later, and Typeable1 on older compilers.
-- NOTE(review): the #if requires the CPP extension, which is not listed in
-- the LANGUAGE pragma at the top of this file; presumably it is enabled by
-- the build configuration - confirm.
#if defined(__GLASGOW_HASKELL__) && (__GLASGOW_HASKELL__ >= 707)
deriving instance Typeable Matrix
#else
deriving instance Typeable1 Matrix
#endif
-- Concrete representations.  Each constructor stores the number of rows,
-- the number of columns, and an array of rows (both levels indexed from 1).
-- Int, Float and Double rows are unboxed UArrays; Integer, Ratio and Complex
-- rows are boxed Arrays.
data instance Matrix Int
    = IntMatrix !Int !Int (Array Int (UArray Int Int))
data instance Matrix Float
    = FloatMatrix !Int !Int (Array Int (UArray Int Float))
data instance Matrix Double
    = DoubleMatrix !Int !Int (Array Int (UArray Int Double))
data instance Matrix Integer
    = IntegerMatrix !Int !Int (Array Int (Array Int Integer))
data instance Matrix (Ratio a)
    = RatioMatrix !Int !Int (Array Int (Array Int (Ratio a)))
data instance Matrix (Complex a)
    = ComplexMatrix !Int !Int (Array Int (Array Int (Complex a)))
-- | Renders one line per row.  Every element is shown with a leading space
-- and elements are separated by a further space.
instance (MatrixElement e, Show e) => Show (Matrix e) where
    show m = unlines [ unwords [ ' ' : show x | x <- r ] | r <- toList m ]
-- | Parses one row per line, consuming the whole input.
--
-- Elements are read one after another with 'reads' rather than by splitting
-- the line on whitespace.  This is what lets the output of the 'Show'
-- instance be read back for element types whose own 'show' contains spaces,
-- such as @1 % 2@ for 'Ratio' or @1.0 :+ 2.0@ for 'Complex'.
-- Blank lines become empty rows.  Trailing text that does not parse as an
-- element is an error.
instance (Read e, MatrixElement e) => Read (Matrix e) where
    readsPrec _ input = [(fromList (P.map readRow (lines input)), "")]
      where
        readRow line = case P.reads line of
            ((x, rest) : _) -> x : readRow rest
            [] | null (words line) -> []
               | otherwise -> error ("Matrix.read: no parse: " P.++ line)
-- | Element-wise '+', '-' and 'abs', and the matrix product for '*'.
-- 'signum' is the 1x1 matrix holding the signum of the determinant, and
-- 'fromInteger' builds a 1x1 matrix.
instance (MatrixElement e) => Num (Matrix e) where
    a + b = plus a b
    a - b = minus a b
    a * b = times a b
    abs m = map abs m
    signum m = matrix (1, 1) (\_ -> signum (det m))
    fromInteger k = matrix (1, 1) (\_ -> fromInteger k)
-- | 'recip' is the matrix inverse; 'fromRational' builds a 1x1 matrix.
instance (MatrixElement e, Fractional e) => Fractional (Matrix e) where
    -- 'inv' yields Nothing for non-square or singular matrices, and for
    -- element types that keep the class default.  Report that case
    -- explicitly instead of failing inside 'fromJust'.
    recip m = fromMaybe (error "Matrix.recip: matrix is not invertible.") (inv m)
    fromRational = matrix (1,1) . const . fromRational
-- | Two matrices are equal when their dimensions agree and every element of
-- the second equals the element at the same index in the first.
instance (MatrixElement e) => Eq (Matrix e) where
    m == n = dimensions m == dimensions n
          && allWithIndex (\ix e -> m `at` ix == e) n
-- | Fully evaluates a matrix by forcing every element.
--
-- This must not be written as @m `deepseq` ()@: 'deepseq' is implemented
-- in terms of 'rnf', so that definition recurses forever.  Each element is
-- forced to weak head normal form.  That is full normal form for Int,
-- Integer, Float and Double.  'Ratio' and 'Complex' have strict fields, so
-- their components are forced as well.
instance (MatrixElement e) => NFData (Matrix e) where
    rnf m = L.foldl' (\acc e -> e `seq` acc) () (L.concat (toList m))
-- | Serialised as the row count, then the column count (both 'Int'), then
-- the elements in row-major order.
instance (MatrixElement e, Binary e) => Binary (Matrix e) where
    put m = do
        let (r, c) = dimensions m
        put r
        put c
        mapM_ put [ m `at` (i, j) | i <- [1..r], j <- [1..c] ]
    get = do
        r <- get :: Get Int
        c <- get :: Get Int
        fromList <$> replicateM r (replicateM c get)
-- | Horizontal concatenation: the columns of the right matrix follow those
-- of the left one.  The shorter matrix is padded with zeros at the bottom.
(<|>) :: MatrixElement e => Matrix e -> Matrix e -> Matrix e
left <|> right = matrix (max rowsL rowsR, colsL + numCols right) pick
  where
    colsL = numCols left
    rowsL = numRows left
    rowsR = numRows right
    pick (i, j)
        | j > colsL = if i > rowsR then 0 else right `at` (i, j - colsL)
        | i > rowsL = 0
        | otherwise = left `at` (i, j)
-- | Vertical concatenation: the rows of the bottom matrix follow those of
-- the top one.  The narrower matrix is padded with zeros on the right.
(<->) :: MatrixElement e => Matrix e -> Matrix e -> Matrix e
top <-> bottom = matrix (rowsT + numRows bottom, max colsT colsB) pick
  where
    rowsT = numRows top
    colsT = numCols top
    colsB = numCols bottom
    pick (i, j)
        | i > rowsT = if j > colsB then 0 else bottom `at` (i - rowsT, j)
        | j > colsT = 0
        | otherwise = top `at` (i, j)
-- | Multiply every element on the right by a scalar.
scale :: MatrixElement e => Matrix e -> e -> Matrix e
scale m s = map (\x -> x * s) m
-- | Structural predicates.  'isUnit' and 'isDiagonal' require a square
-- matrix; 'isEmpty' holds when there are no rows or no columns.
isUnit, isDiagonal, isZero, isEmpty, isSquare :: MatrixElement e => Matrix e -> Bool
isZero m = all (\x -> x == 0) m
isUnit m = isSquare m && allWithIndex unitEntry m
  where
    unitEntry (i, j) x
        | i == j = x == 1
        | otherwise = x == 0
isEmpty m = numRows m == 0 || numCols m == 0
isDiagonal m = isSquare m && allWithIndex offDiagonalZero m
  where
    offDiagonalZero (i, j) x = i == j || x == 0
isSquare m = case dimensions m of
    (r, c) -> r == c
-- | Convert an integral matrix to 'Double' elements with 'fromIntegral'.
toDoubleMatrix :: (MatrixElement a, Integral a) => Matrix a -> Matrix Double
toDoubleMatrix m = map fromIntegral m
-- | Convert a real-valued matrix to exact 'Rational' elements.
toRationalMatrix :: (MatrixElement a, Real a) => Matrix a -> Matrix Rational
toRationalMatrix m = map toRational m
-- | Embed a real matrix in the complex numbers (imaginary parts are zero).
toComplexMatrix :: (MatrixElement a, RealFloat a, Show a) => Matrix a -> Matrix (Complex a)
toComplexMatrix m = map (\x -> x :+ 0) m
-- | The division used by the fraction-free elimination in '_det' and
-- '_rank': truncating 'quot' for integral types and '(/)' for fractional
-- ones.
class Division e where
    divide :: e -> e -> e

instance Division Int     where divide x y = x `quot` y
instance Division Int8    where divide x y = x `quot` y
instance Division Int16   where divide x y = x `quot` y
instance Division Int32   where divide x y = x `quot` y
instance Division Int64   where divide x y = x `quot` y
instance Division Word8   where divide x y = x `quot` y
instance Division Word16  where divide x y = x `quot` y
instance Division Word32  where divide x y = x `quot` y
instance Division Word64  where divide x y = x `quot` y
instance Division Integer where divide x y = x `quot` y
instance Division Float   where divide x y = x / y
instance Division Double  where divide x y = x / y
instance Integral a => Division (Ratio a) where divide x y = x / y
instance RealFloat a => Division (Complex a) where divide x y = x / y
-- | Element types that can be stored in a 'Matrix'.
--
-- Indices are 1-based: @(1,1)@ is the top-left element of a matrix.
-- 'matrix', 'fromList', 'toList', 'det' and 'rank' have no default.  The
-- other methods default to list-based implementations, which instances may
-- override with faster ones.
class (Eq e, Num e) => MatrixElement e where
    -- | @matrix (r, c) f@ builds an @r@-by-@c@ matrix whose element at
    -- @(i, j)@ is @f (i, j)@.
    matrix :: (Int, Int) -> ((Int, Int) -> e) -> Matrix e
    -- | Elements whose index satisfies the predicate, in row-major order.
    select :: ((Int, Int) -> Bool) -> Matrix e -> [e]
    -- | Element at a 1-based @(row, column)@ index.
    at :: Matrix e -> (Int, Int) -> e
    -- | The @i@-th row (1-based).
    row :: Int -> Matrix e -> [e]
    -- | The @j@-th column (1-based).
    col :: Int -> Matrix e -> [e]
    -- | @(rows, columns)@.
    dimensions :: Matrix e -> (Int, Int)
    numRows :: Matrix e -> Int
    numCols :: Matrix e -> Int
    -- | Build a matrix from a list of rows.
    fromList :: [[e]] -> Matrix e
    -- | The rows of the matrix, as lists.
    toList :: Matrix e -> [[e]]
    -- | The @n@-by-@n@ identity matrix.
    unit :: Int -> Matrix e
    -- | The @n@-by-@n@ zero matrix.
    zero :: Int -> Matrix e
    -- | Square matrix with the given diagonal and zeros elsewhere.
    diag :: [e] -> Matrix e
    -- | The matrix built from an empty list of rows.
    empty :: Matrix e
    -- | Element-wise difference; the dimensions must match.
    minus :: Matrix e -> Matrix e -> Matrix e
    -- | Element-wise sum; the dimensions must match.
    plus :: Matrix e -> Matrix e -> Matrix e
    -- | Matrix product; @numCols a@ must equal @numRows b@.
    times :: Matrix e -> Matrix e -> Matrix e
    -- | Inverse, or 'Nothing'.  The default is always 'Nothing'.
    inv :: Matrix e -> Maybe (Matrix e)
    det :: Matrix e -> e
    transpose :: Matrix e -> Matrix e
    -- | Rank, returned in the element type.
    rank :: Matrix e -> e
    -- | The diagonal elements (index @(i, i)@), in order.
    trace :: Matrix e -> [e]
    -- | Determinant of 'minorMatrix'.
    minor :: MatrixElement e => (Int, Int) -> Matrix e -> e
    -- | The matrix with row @i@ and column @j@ removed.
    minorMatrix :: MatrixElement e => (Int, Int) -> Matrix e -> Matrix e
    -- | Matrix of signed minors, @(-1)^(i+j) * minor (i, j)@.
    cofactors :: MatrixElement e => Matrix e -> Matrix e
    -- | Transpose of 'cofactors'.
    adjugate :: MatrixElement e => Matrix e -> Matrix e
    map :: MatrixElement f => (e -> f) -> Matrix e -> Matrix f
    all :: (e -> Bool) -> Matrix e -> Bool
    any :: (e -> Bool) -> Matrix e -> Bool
    sum :: Matrix e -> e
    foldMap :: Monoid m => (e -> m) -> Matrix e -> m
    mapWithIndex :: MatrixElement f => ((Int, Int) -> e -> f) -> Matrix e -> Matrix f
    allWithIndex :: ((Int, Int) -> e -> Bool) -> Matrix e -> Bool
    anyWithIndex :: ((Int, Int) -> e -> Bool) -> Matrix e -> Bool
    foldMapWithIndex :: Monoid m => ((Int, Int) -> e -> m) -> Matrix e -> m

    unit n = fromList [[ if i == j then 1 else 0 | j <- [1..n]] | i <- [1..n] ]
    zero n = matrix (n,n) (const 0)
    empty = fromList []
    diag xs = matrix (n,n) (\(i,j) -> if i == j then xs !! (i-1) else 0)
      where n = length xs
    select p m = [ at m (i,j) | i <- [1..numRows m]
                              , j <- [1..numCols m]
                              , p (i,j) ]
    -- Matrix indices are 1-based but '!!' is 0-based, hence the offsets
    -- (consistent with the default 'row' below).
    at mat (i, j) = toList mat !! (i - 1) !! (j - 1)
    row i = (!! (i-1)) . toList
    col i = row i . transpose
    numRows = fst . dimensions
    numCols = snd . dimensions
    dimensions m = case toList m of
        [] -> (0, 0)
        (x:xs) -> (length xs + 1, length x)
    adjugate = transpose . cofactors
    transpose mat = matrix (n, m) (\(i,j) -> mat `at` (j,i))
      where (m, n) = dimensions mat
    trace = select (uncurry (==))
    inv _ = Nothing
    -- Indices at or past the removed row/column are shifted by one.
    minorMatrix (i,j) mat = matrix (numRows mat - 1, numCols mat - 1) $
        \(i',j') -> mat `at` (if i' >= i then i' + 1 else i',
                              if j' >= j then j' + 1 else j')
    minor ix = det . minorMatrix ix
    cofactors mat = matrix (dimensions mat) $
        \(i,j) -> fromIntegral ((-1 :: Int)^(i+j)) * minor (i,j) mat
    map f = mapWithIndex (const f)
    all f = allWithIndex (const f)
    any f = anyWithIndex (const f)
    sum = getSum . foldMap Sum
    foldMap f = foldMapWithIndex (const f)
    mapWithIndex f m = matrix (dimensions m) (\x -> f x (m `at` x))
    allWithIndex f m = P.all id [ f (i, j) (m `at` (i,j))
                                | i <- [1..numRows m], j <- [1..numCols m]]
    anyWithIndex f m = P.any id [ f (i, j) (m `at` (i,j))
                                | i <- [1..numRows m], j <- [1..numCols m]]
    foldMapWithIndex f m = mconcat [ f (i, j) (m `at` (i,j))
                                   | i <- [1..numRows m], j <- [1..numCols m]]
    a `plus` b
        | dimensions a /= dimensions b = error "Matrix.plus: dimensions don't match."
        | otherwise = matrix (dimensions a) (\x -> a `at` x + b `at` x)
    a `minus` b
        | dimensions a /= dimensions b = error "Matrix.minus: dimensions don't match."
        | otherwise = matrix (dimensions a) (\x -> a `at` x - b `at` x)
    a `times` b
        | numCols a /= numRows b = error "Matrix.times: `numCols a' and `numRows b' don't match."
        | otherwise = _mult a b
-- | 'Int' matrices keep each row in an unboxed 'UArray'.  'inv' is not
-- overridden, so the class default applies.
instance MatrixElement Int where
    dimensions (IntMatrix r c _) = (r, c)
    matrix dims gen = runST (_matrix IntMatrix arrayST arraySTU dims gen)
    fromList = _fromList IntMatrix
    toList (IntMatrix _ _ rows) = _toList rows
    at (IntMatrix _ _ rows) = _at rows
    row i (IntMatrix _ _ rows) = _row i rows
    col j (IntMatrix _ _ rows) = _col j rows
    det (IntMatrix r c rows)
        | r == c = runST (_det thawsUnboxed rows)
        | otherwise = 0
    rank (IntMatrix _ _ rows) = runST (_rank thawsBoxed rows)
-- | 'Integer' matrices keep each row in a boxed 'Array'.  'inv' is not
-- overridden, so the class default applies.
instance MatrixElement Integer where
    dimensions (IntegerMatrix r c _) = (r, c)
    matrix dims gen = runST (_matrix IntegerMatrix arrayST arrayST dims gen)
    fromList = _fromList IntegerMatrix
    toList (IntegerMatrix _ _ rows) = _toList rows
    at (IntegerMatrix _ _ rows) = _at rows
    row i (IntegerMatrix _ _ rows) = _row i rows
    col j (IntegerMatrix _ _ rows) = _col j rows
    det (IntegerMatrix r c rows)
        | r == c = runST (_det thawsBoxed rows)
        | otherwise = 0
    rank (IntegerMatrix _ _ rows) = runST (_rank thawsBoxed rows)
-- | 'Float' matrices keep each row in an unboxed 'UArray'.  'inv' picks the
-- candidate with the largest absolute value as pivot ('pivotMax').
instance MatrixElement Float where
    dimensions (FloatMatrix r c _) = (r, c)
    matrix dims gen = runST (_matrix FloatMatrix arrayST arraySTU dims gen)
    fromList = _fromList FloatMatrix
    toList (FloatMatrix _ _ rows) = _toList rows
    at (FloatMatrix _ _ rows) = _at rows
    row i (FloatMatrix _ _ rows) = _row i rows
    col j (FloatMatrix _ _ rows) = _col j rows
    det (FloatMatrix r c rows)
        | r == c = runST (_det thawsUnboxed rows)
        | otherwise = 0
    rank (FloatMatrix _ _ rows) = runST (_rank thawsBoxed rows)
    inv (FloatMatrix r c rows)
        | r /= c = Nothing
        | otherwise = FloatMatrix r c <$> runST (_inv unboxedST pivotMax rows)
-- | 'Double' matrices keep each row in an unboxed 'UArray'.  'inv' picks the
-- candidate with the largest absolute value as pivot ('pivotMax').
instance MatrixElement Double where
    dimensions (DoubleMatrix r c _) = (r, c)
    matrix dims gen = runST (_matrix DoubleMatrix arrayST arraySTU dims gen)
    fromList = _fromList DoubleMatrix
    toList (DoubleMatrix _ _ rows) = _toList rows
    at (DoubleMatrix _ _ rows) = _at rows
    row i (DoubleMatrix _ _ rows) = _row i rows
    col j (DoubleMatrix _ _ rows) = _col j rows
    det (DoubleMatrix r c rows)
        | r == c = runST (_det thawsUnboxed rows)
        | otherwise = 0
    rank (DoubleMatrix _ _ rows) = runST (_rank thawsBoxed rows)
    inv (DoubleMatrix r c rows)
        | r /= c = Nothing
        | otherwise = DoubleMatrix r c <$> runST (_inv unboxedST pivotMax rows)
-- | Exact rational matrices, with each row in a boxed 'Array'.  'inv' picks
-- the candidate with the largest absolute value as pivot ('pivotMax').
instance (Show a, Integral a) => MatrixElement (Ratio a) where
    dimensions (RatioMatrix r c _) = (r, c)
    matrix dims gen = runST (_matrix RatioMatrix arrayST arrayST dims gen)
    fromList = _fromList RatioMatrix
    toList (RatioMatrix _ _ rows) = _toList rows
    at (RatioMatrix _ _ rows) = _at rows
    row i (RatioMatrix _ _ rows) = _row i rows
    col j (RatioMatrix _ _ rows) = _col j rows
    det (RatioMatrix r c rows)
        | r == c = runST (_det thawsBoxed rows)
        | otherwise = 0
    rank (RatioMatrix _ _ rows) = runST (_rank thawsBoxed rows)
    inv (RatioMatrix r c rows)
        | r /= c = Nothing
        | otherwise = RatioMatrix r c <$> runST (_inv boxedST pivotMax rows)
-- | Complex matrices, with each row in a boxed 'Array'.  'inv' uses
-- 'pivotNonZero' (the first non-zero candidate), because 'pivotMax' needs an
-- 'Ord' instance on the candidate values.
instance (Show a, RealFloat a) => MatrixElement (Complex a) where
    dimensions (ComplexMatrix r c _) = (r, c)
    matrix dims gen = runST (_matrix ComplexMatrix arrayST arrayST dims gen)
    fromList = _fromList ComplexMatrix
    toList (ComplexMatrix _ _ rows) = _toList rows
    at (ComplexMatrix _ _ rows) = _at rows
    row i (ComplexMatrix _ _ rows) = _row i rows
    col j (ComplexMatrix _ _ rows) = _col j rows
    det (ComplexMatrix r c rows)
        | r == c = runST (_det thawsBoxed rows)
        | otherwise = 0
    rank (ComplexMatrix _ _ rows) = runST (_rank thawsBoxed rows)
    inv (ComplexMatrix r c rows)
        | r /= c = Nothing
        | otherwise = ComplexMatrix r c <$> runST (_inv boxedST pivotNonZero rows)
-- | Element @(i, j)@ of an array of rows: row @i@, then column @j@.
_at :: (IArray a (u Int e), IArray u e)
    => a Int (u Int e) -> (Int, Int) -> e
_at arr (i, j) =
    let r = arr ! i
    in r ! j
-- | Row @i@ and column @j@ of an array of rows.  Both read indices from 1 up
-- to the relevant upper bound.
_row, _col :: (IArray a (u Int e), IArray u e) => Int -> a Int (u Int e) -> [e]
_row i arr = [ r ! j | let r = arr ! i, j <- [1 .. snd (bounds r)] ]
_col j arr = [ r ! j | i <- [1 .. snd (bounds arr)], let r = arr ! i ]
-- | The rows of an array of rows, as nested lists in index order.
_toList :: (IArray a e) => Array Int (a Int e) -> [[e]]
_toList outer = [ elems r | r <- elems outer ]
-- | Build a matrix from a list of rows using the given constructor, which
-- receives the row count, the column count and the 1-indexed array of rows.
-- Ragged input is truncated to the length of the shortest row.
_fromList :: (IArray a (u Int e), IArray u e)
          => (Int -> Int -> a Int (u Int e) -> matrix e) -> [[e]] -> matrix e
_fromList mk xs = mk rowCount colCount rowsArr
  where
    rowCount = length xs
    colCount = case xs of
        [] -> 0
        _  -> foldl1 min [ length r | r <- xs ]
    mkRow r = array (1, colCount) (zip [1..colCount] r)
    rowsArr = array (1, rowCount) (zip [1..rowCount] [ mkRow r | r <- xs ])
-- | Thaw every row into a fresh boxed mutable 'STArray'.
thawsBoxed :: (IArray a e, MArray (STArray s) e (ST s))
           => Array Int (a Int e)
           -> ST s [STArray s Int e]
thawsBoxed arr = forM (elems arr) thaw
-- | Thaw every row into a fresh unboxed mutable 'STUArray'.
thawsUnboxed :: (IArray a e, MArray (STUArray s) e (ST s))
             => Array Int (a Int e)
             -> ST s [STUArray s Int e]
thawsUnboxed arr = forM (elems arr) thaw
-- | Gather mutable rows into a boxed mutable array indexed from 1.
arrays :: [(u s) Int e]
       -> ST s ((STArray s) Int ((u s) Int e))
arrays rowList = newListArray (1, count) rowList
  where count = length rowList
-- | Build the augmented matrix @[A | I]@ used by '_inv'.
--
-- Row @i@ of the result is row @i@ of @arr@ followed by row @i@ of the
-- @n@-by-@n@ identity, where @n@ is the row count of @arr@.  Each row is
-- created with the supplied constructor @mkRow@.  Previously that argument
-- was ignored and 'newListArray' was always used.
-- NOTE(review): assumes @arr@ is square with rows indexed from 1; the
-- caller ('_inv') is only reached for square matrices.
augment :: (IArray a e, MArray (u s) e (ST s), Num e)
        => ((Int, Int) -> [e] -> ST s ((u s) Int e))
        -> Array Int (a Int e)
        -> ST s (STArray s Int (u s Int e))
augment mkRow arr = do
    let (_, n) = bounds arr
        row (a, i) = mkRow (1, 2*n)
                           [ if j > n then (if j == i + n then 1 else 0)
                                      else a ! j
                           | j <- [1..2*n] ]
    mapM row (zip (elems arr) [1..]) >>= newListArray (1, n)
-- | Boxed mutable array with the given bounds, filled from a list.
boxedST :: MArray (STArray s) e (ST s)
        => (Int, Int) -> [e] -> ST s ((STArray s) Int e)
boxedST bnds xs = newListArray bnds xs
-- | Unboxed mutable array with the given bounds, filled from a list.
unboxedST :: MArray (STUArray s) e (ST s)
          => (Int, Int) -> [e] -> ST s ((STUArray s) Int e)
unboxedST bnds xs = newListArray bnds xs
-- | Boxed mutable array with the given bounds, every cell set to one value.
arrayST :: MArray (STArray s) e (ST s)
        => (Int, Int) -> e -> ST s ((STArray s) Int e)
arrayST bnds fill = newArray bnds fill
-- | Unboxed mutable array with the given bounds, every cell set to one
-- value.
arraySTU :: MArray (STUArray s) e (ST s)
         => (Int, Int) -> e -> ST s ((STUArray s) Int e)
arraySTU bnds fill = newArray bnds fill
-- | Run an action on a value for its effect only, then return the value.
tee :: Monad m => (b -> m a) -> b -> m b
tee f x = do
    _ <- f x
    return x
-- | Read element @(i, j)@ from a mutable array of mutable rows.
read :: (MArray a1 b m, MArray a (a1 Int b) m) =>
        a Int (a1 Int b) -> Int -> Int -> m b
read a i j = do
    r <- readArray a i
    readArray r j
-- | Index paired with the largest value.  On ties the last maximal
-- candidate wins, since 'L.maximumBy' is used.
pivotMax :: Ord v => [(i, v)] -> i
pivotMax candidates =
    fst (L.maximumBy (\x y -> compare (snd x) (snd y)) candidates)
-- | Index of the first candidate with a non-zero value.  If there is none,
-- the index of the first candidate is returned (the list must be
-- non-empty).
pivotNonZero :: (Num v, Eq v) => [(i, v)] -> i
pivotNonZero candidates =
    case L.find (\(_, v) -> v /= 0) candidates of
        Just (i, _) -> i
        Nothing -> fst (head candidates)
-- | Invert a matrix by Gauss-Jordan elimination on the augmented matrix
-- @[A | I]@ built by 'augment'.
--
-- @mkArrayST@ is handed to 'augment'.  For each column, @selectPivot@
-- chooses the pivot row from @(rowIndex, abs value)@ pairs.  The result is
-- 'Nothing' if the chosen pivot for some column is zero.  Otherwise it is
-- the right half of the reduced augmented matrix.
-- NOTE(review): assumes @mat@ is square with rows indexed from 1; callers in
-- this file check squareness first.  The 'Show e' constraint is not used in
-- the body.
_inv :: (IArray a e, MArray (u s) e (ST s), Fractional e, Show e, Eq e)
     => ((Int, Int) -> [e] -> ST s ((u s) Int e))
     -> ([(Int, e)] -> Int)
     -> Array Int (a Int e)
     -> ST s (Maybe (Array Int (a Int e)))
_inv mkArrayST selectPivot mat = do
    let m = snd $ bounds mat
        n = 2 * m
        -- Exchange rows i and j of the outer row array.
        swap a i j = do
            tmp <- readArray a i
            readArray a j >>= writeArray a i
            writeArray a j tmp
    -- Cleared as soon as a zero pivot is met.
    okay <- newSTRef True
    a <- augment mkArrayST mat
    -- Forward elimination: zero column k below the diagonal.
    forM_ [1..m] $ \k -> do
        iPivot <- selectPivot <$> zip [k..m]
                              <$> mapM (\i -> abs <$> read a i k) [k..m]
        p <- read a iPivot k
        if p == 0 then writeSTRef okay False else do
            swap a iPivot k
            forM_ [k+1..m] $ \i -> do
                a_i <- readArray a i
                a_k <- readArray a k
                forM_ [k+1..n] $ \j -> do
                    a_ij <- readArray a_i j
                    a_kj <- readArray a_k j
                    a_ik <- readArray a_i k
                    writeArray a_i j (a_ij - a_kj * (a_ik / p))
                writeArray a_i k 0
    invertible <- readSTRef okay
    if invertible then
        do
            -- Back substitution, bottom row first.  Row i is scaled so its
            -- diagonal is 1.  Then, for each column k > i, the already
            -- reduced row k is used to clear column k of row i.
            forM_ [ m - v | v <- [0..m-1] ] $ \i -> do
                a_i <- readArray a i
                p <- readArray a_i i
                writeArray a_i i 1
                forM_ [i+1..n] $ \j -> do
                    readArray a_i j >>= writeArray a_i j . (/ p)
                unless (i == m) $ do
                    forM_ [i+1..m] $ \k -> do
                        a_k <- readArray a k
                        p <- readArray a_i k
                        forM_ [k..n] $ \j -> do
                            a_ij <- readArray a_i j
                            a_kj <- readArray a_k j
                            writeArray a_i j (a_ij - p * a_kj)
            -- Keep only columns m+1..2m of every row: the inverse.
            mapM (\i -> readArray a i >>= getElems
                        >>= return . listArray (1, m) . drop m) [1..m]
                >>= return . Just . listArray (1, m)
        else return Nothing
-- | Rank of a matrix, returned in the element type.
--
-- A mutable copy of the rows (made with @thaws@) is reduced to row echelon
-- form column by column.  The update is the fraction-free (Bareiss-style)
-- one also used by '_det':
--
-- > a_ij <- (pivot * a_ij - a_ik * a_pj) `divide` prev
--
-- Here @prev@ is the previous pivot.  The rank is the number of pivots
-- found.  Columns without a non-zero candidate are skipped, so the pivot row
-- ('pivotRow') can lag behind the column index @k@.  The pivot must
-- therefore be read from row 'pivotRow', not row @k@.  Reading row @k@ was a
-- bug: for [[0,2,4],[0,3,5]] it reported rank 1 instead of 2, and it could
-- index past the last row.
-- NOTE(review): for Float/Double, entries are compared with exactly 0, so
-- round-off may affect the result.
_rank :: (IArray a e, MArray (u s) e (ST s), Num e, Division e, Eq e)
      => (Array Int (a Int e) -> ST s [(u s) Int e])
      -> Array Int (a Int e)
      -> ST s e
_rank thaws mat = do
    let m = snd $ bounds mat
        -- A matrix with no rows has no row 1 to inspect.  Treat it as having
        -- no columns, so no elimination happens and the rank is 0.
        n | m == 0    = 0
          | otherwise = snd $ bounds (mat ! 1)
        swap a i j = do
            tmp <- readArray a i
            readArray a j >>= writeArray a i
            writeArray a j tmp
    a <- thaws mat >>= arrays
    -- Row that will receive the next pivot; its final value is rank + 1.
    ixPivot <- newSTRef 1
    -- Previous pivot: the exact divisor of the fraction-free update.
    prevR <- newSTRef 1
    forM_ [1..n] $ \k -> do
        pivotRow <- readSTRef ixPivot
        switchRow <- L.findIndex (/= 0) <$> mapM (\i -> read a i k) [pivotRow .. m]
        case switchRow of
            Nothing -> return ()
            Just offset -> do
                let ix = offset + pivotRow
                when (pivotRow /= ix) (swap a pivotRow ix)
                a_p <- readArray a pivotRow
                pivot <- readArray a_p k
                prev <- readSTRef prevR
                forM_ [pivotRow+1..m] $ \i -> do
                    a_i <- readArray a i
                    forM_ [k+1..n] $ \j -> do
                        a_ij <- readArray a_i j
                        a_ik <- readArray a_i k
                        a_pj <- readArray a_p j
                        writeArray a_i j ((pivot * a_ij - a_ik * a_pj)
                                          `divide` prev)
                writeSTRef ixPivot (pivotRow + 1)
                writeSTRef prevR pivot
    nextPivotRow <- readSTRef ixPivot
    return (fromIntegral nextPivotRow - 1)
-- | Determinant by fraction-free (Bareiss) elimination on a mutable copy of
-- the rows, made with @thaws@.
--
-- At step k the entries below and right of the pivot are replaced by
-- @(pivot * a_ij - a_ik * a_kj) `divide` prev@, where @prev@ is the
-- previous step's pivot.  A zero pivot is replaced by swapping in the first
-- later row with a non-zero entry in column k; each swap negates the sign.
-- If no such row exists, the sign becomes 0 and the remaining steps do
-- nothing.  The result is (last pivot) * sign.
-- NOTE(review): only the row count ('size') is read; callers in this file
-- check squareness first.  A 0x0 input yields 1.
_det :: (IArray a e, MArray (u s) e (ST s),
         Num e, Eq e, Division e)
     => (Array Int (a Int e) -> ST s [(u s) Int e])
     -> Array Int (a Int e) -> ST s e
_det thaws mat = do
    let size = snd $ bounds mat
    a <- thaws mat >>= arrays
    -- +1 / -1 from row swaps, or 0 once the matrix is known to be singular.
    signR <- newSTRef 1
    -- The current pivot; after the loop it holds the last one.
    pivotR <- newSTRef 1
    forM_ [1..size] $ \k -> do
        sign <- readSTRef signR
        unless (sign == 0) $ do
            prev <- readSTRef pivotR
            pivot <- read a k k >>= tee (writeSTRef pivotR)
            when (pivot == 0) $ do
                -- Row numbers below k with a non-zero entry in column k
                -- (0 marks rows without one).
                s <- forM [(k+1)..size] $ \r -> do
                    a_rk <- read a r k
                    if a_rk == 0 then return 0 else return r
                let sf = filter (>0) s
                when (not $ null sf) $ do
                    let sw = head sf
                    row <- readArray a sw
                    readArray a k >>= writeArray a sw
                    writeArray a k row
                    read a k k >>= writeSTRef pivotR
                    readSTRef signR >>= writeSTRef signR . negate
                when (null sf) (writeSTRef signR 0)
            sign' <- readSTRef signR
            unless (sign' == 0) $ do
                pivot' <- readSTRef pivotR
                forM_ [(k+1)..size] $ \i -> do
                    a_i <- readArray a i
                    forM [(k+1)..size] $ \j -> do
                        a_ij <- readArray a_i j
                        a_ik <- readArray a_i k
                        a_kj <- read a k j
                        writeArray a_i j ((pivot' * a_ij - a_ik * a_kj) `divide` prev)
    liftM2 (*) (readSTRef pivotR) (readSTRef signR)
-- | Matrix product by the row-times-column definition, with a strict sum
-- for each entry.  Callers are expected to have checked the dimensions.
_mult :: MatrixElement e => Matrix e -> Matrix e -> Matrix e
_mult a b = matrix (numRows a, numCols b) entry
  where
    shared = numRows b
    entry (i, j) = L.foldl' (+) 0 [ a `at` (i, k) * b `at` (k, j) | k <- [1..shared] ]
-- | Build an @m@-by-@n@ matrix from an index function inside 'ST'.
--
-- @newOuter@ allocates the mutable array of rows; its fill value is
-- 'undefined' because every slot is overwritten.  @newInner@ allocates each
-- row, zero-filled.  Rows and the outer array are frozen with
-- 'U.unsafeFreeze'.  This is safe because neither is written after it is
-- frozen.  @mk@ receives the row count, the column count and the frozen
-- rows.
_matrix :: (IArray a1 (u Int e), IArray u e,
            MArray a2 (u Int e) (ST s), MArray a3 e (ST s),
            Num e)
        => (Int -> Int -> a1 Int (u Int e) -> matrix)
        -> ((Int, Int) -> a -> ST s (a2 Int (u Int e)))
        -> ((Int, Int) -> e -> ST s (a3 Int e))
        -> (Int, Int)
        -> ((Int, Int) -> e)
        -> ST s matrix
_matrix mk newOuter newInner (m, n) gen = do
    outer <- newOuter (1, m) undefined
    forM_ [1..m] $ \i -> do
        cells <- newInner (1, n) 0
        forM_ [1..n] $ \j -> writeArray cells j (gen (i, j))
        frozen <- U.unsafeFreeze cells
        writeArray outer i frozen
    frozenOuter <- U.unsafeFreeze outer
    return (mk m n frozenOuter)
{-# RULES
"det/pow"
forall a k. det (a ^ k) = (det a) ^ k
#-}